Example #1
0
def main(args):
    """Run a trained segmentation model over a dataset and save predictions.

    Every image yielded by the loader is passed through the network, the
    output is upsampled to 1024x2048 and reduced with ``argmax`` to a
    per-pixel class index map. Two files are written to ``args.save`` per
    image: the raw index map (under the input's base name) and a colorized
    mask named ``<stem>_color.png``.

    Args:
        args: Parsed options. Attributes read here: ``gpu``, ``save``,
            ``model``, ``num_classes``, ``restore_from``, ``data_dir``,
            ``img_list``, ``lbl_list`` and ``set``.

    Raises:
        ValueError: If ``args.model`` is neither ``'DeeplabMulti'`` nor
            ``'DeeplabVGG'``.
    """
    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        # Only swap in the VGG checkpoint when the default was left untouched.
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early with a clear message instead of an unbound-variable
        # error further down.
        raise ValueError('Unsupported model type: %r' % (args.model,))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(ListDataSet(args.data_dir,
                                             args.img_list,
                                             args.lbl_list,
                                             crop_size=(1024, 512),
                                             mean=IMG_MEAN,
                                             split=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    # DeeplabMulti returns two outputs here; only the second one is evaluated.
    multi_output = args.model == 'DeeplabMulti'

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, _, name = batch

            prediction = model(Variable(image).cuda(gpu0))
            if multi_output:
                _, prediction = prediction
            output = interp(prediction).cpu().data[0].numpy()

            # Move the leading axis last, then pick the best class per pixel.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' %
                            (args.save, name.split('.')[0]))
Example #2
0
def main():
    """Evaluate a segmentation checkpoint and save per-image predictions.

    Options come from ``get_arguments()``. Each image from the
    ``cityscapesDataSet`` loader is passed through the model, the output is
    upsampled to the size given by ``args.input_size`` (a ``"W,H"`` string),
    reduced with ``argmax`` to a class index map and saved to ``args.save``
    as the raw index image plus a colorized ``<stem>_color.png``.

    Raises:
        ValueError: If ``args.model`` is not ``'DeeplabMulti'``, ``'Oracle'``
            or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model type: %r' % (args.model,))

    device = torch.device("cuda" if not args.cpu else "cpu")

    # map_location lets a checkpoint saved on a GPU be loaded when --cpu is
    # requested (or when the original GPU index is not available).
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location=device)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location=device)
    model.load_state_dict(saved_state_dict)

    model = model.to(device)
    model.eval()

    testloader = data.DataLoader(
        cityscapesDataSet(args.data_dir, args.data_list,
                          crop_size=input_size, mean=IMG_MEAN,
                          scale=False, mirror=False, set=args.set),
        batch_size=1, shuffle=False, pin_memory=True)

    # nn.Upsample takes (height, width), hence the swapped indices.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)

    # This DeeplabMulti variant returns four values; the second is evaluated.
    multi_output = args.model == 'DeeplabMulti'

    # Pure inference: disable autograd so no graph is kept per image.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            if multi_output:
                _, prediction, _, _ = model(image)
            else:
                prediction = model(image)
            output = interp(prediction).cpu().data[0].numpy()

            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' %
                            (args.save, name.split('.')[0]))
Example #3
0
def main():
    """Evaluate a segmentation checkpoint and save per-image predictions.

    Options come from ``get_arguments()``. Each image from the
    ``cityscapesDataSet`` loader is passed through the model, the output is
    upsampled to 1024x2048, reduced with ``argmax`` to a class index map and
    saved to ``args.save`` as the raw index image plus a colorized
    ``<stem>_color.png``.

    Raises:
        ValueError: If ``args.model`` is neither ``'Deeplab'`` nor
            ``'DeeplabVGG'``.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'Deeplab':
        model = Res_Deeplab(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported model type: %r' % (args.model,))

    saved_state_dict = torch.load(args.restore_from)

    # For running different versions of PyTorch: drop checkpoint entries the
    # model does not know, and keep the model's own values for entries the
    # checkpoint lacks. The merged dict is what must be loaded; loading the
    # filtered checkpoint alone would still fail on missing keys.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(
        cityscapesDataSet(args.data_dir, args.data_list,
                          crop_size=(1024, 512), mean=IMG_MEAN,
                          scale=False, mirror=False, set=args.set),
        batch_size=1, shuffle=False, pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    # torch.no_grad() exists from PyTorch 0.4 on, where ``volatile=True`` no
    # longer has any effect; older releases need the volatile Variable.
    has_no_grad = hasattr(torch, 'no_grad')

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % index)
        image, _, name = batch

        # Both supported models return a single output tensor here.
        if has_no_grad:
            with torch.no_grad():
                output = interp(model(image.cuda(gpu0)))
        else:
            output = interp(model(Variable(image, volatile=True).cuda(gpu0)))
        output = output.cpu().data[0].numpy()

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

        output_col = colorize_mask(output)
        output = Image.fromarray(output)

        name = name[0].split('/')[-1]
        output.save('%s/%s' % (args.save, name))
        output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))
Example #4
0
def eval(pth, cityscapes_eval_dir, i_iter):
    """Run the checkpoint at ``pth`` over the Cityscapes loader and save masks.

    Remaining options (model type, data paths, GPU id, ...) come from
    ``get_arguments()``. For every processed image the upsampled network
    output is reduced with ``argmax`` to a class index map, which is written
    to ``cityscapes_eval_dir`` together with a colorized ``<stem>_color.png``.

    Args:
        pth: Path of the saved state dict, handed to ``torch.load``.
        cityscapes_eval_dir: Output directory; created when it does not exist.
        i_iter: Iteration counter supplied by the caller. When it equals ``0``
            only the first image is processed.

    Raises:
        ValueError: If ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if args.model == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes)
    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported model type: %r' % (args.model,))
    model.load_state_dict(torch.load(pth))

    model.eval()
    model.cuda(gpu0)

    # Image.save does not create missing directories.
    if not os.path.exists(cityscapes_eval_dir):
        os.makedirs(cityscapes_eval_dir)

    cityscapesloader = data.DataLoader(
        cityscapesDataSet(args.cityscapes_data_dir,
                          args.cityscapes_data_list,
                          crop_size=(1024, 512),
                          mean=IMG_MEAN,
                          scale=False,
                          mirror=False,
                          set=args.set),
        batch_size=1,
        shuffle=False,
        pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    with torch.no_grad():
        for index, batch in enumerate(cityscapesloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, name = batch

            output = model(Variable(image).cuda(gpu0))
            output = interp(output).cpu().data[0].numpy()

            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (cityscapes_eval_dir, name))
            output_col.save('%s/%s_color.png' %
                            (cityscapes_eval_dir, name.split('.')[0]))

            # At iteration 0 a single image is enough.
            if i_iter == 0:
                break
Example #5
0
def main():
    """Accumulate per-class feature centers over the GTA5 loader and save them.

    Options come from ``get_arguments()``. For every batch the model's first
    output is handed to ``class_center_precal`` together with the labels; the
    returned centers and counts are summed into a (19, 256) and a (19, 1)
    array. The summed centers divided by the counts are written to
    ``./source_center.npy``.
    """
    args = get_arguments()
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args.cuda_device_id)

    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes)

    device = torch.device("cpu" if args.cpu else "cuda")

    checkpoint = torch.load(args.restore_from)
    model.load_state_dict(checkpoint)

    model.eval()
    model.to(device)

    source_dataset = GTA5DataSet(args.data_dir,
                                 args.data_list,
                                 crop_size=(1024, 512),
                                 scale=False,
                                 mirror=False,
                                 mean=IMG_MEAN)
    source_loader = data.DataLoader(source_dataset,
                                    batch_size=1,
                                    shuffle=False,
                                    pin_memory=True)

    class_counts = np.zeros((19, 1))
    center_sums = np.zeros((19, 256))

    for step, (images, labels, _, _) in enumerate(source_loader):
        if step % 100 == 0:
            print('%d processd' % step)
        images = images.to(device)
        labels = labels.long().to(device)

        with torch.no_grad():
            feature, _ = model(images)

        batch_center, batch_count = class_center_precal(feature, labels)
        class_counts = class_counts + batch_count.numpy()
        center_sums += batch_center.cpu().data[0].numpy()

    # Classes with a zero count get a divisor of 1 to avoid dividing by zero.
    class_counts[class_counts == 0] = 1

    np.save('./source_center.npy', center_sums / class_counts)
def main():
    """Predict segmentation masks with multi-scale, flipped inference.

    Options come from ``get_arguments()``; the model type and normalization
    settings are read from ``opts.yaml`` next to the checkpoint. For the
    ``'DeepLab'`` model each batch is evaluated four times (two input scales,
    each also passed through ``fliplr``) and the softmax maps are summed
    before the per-pixel ``argmax``. The other models use a single forward
    pass. Masks are written to ``<args.save>/<parent dir of the image>/``
    as the raw index image plus a colorized ``<stem>_color.png``.

    Returns:
        The output root directory, ``args.save``.

    Raises:
        ValueError: If the configured model is not ``'DeepLab'``,
            ``'Oracle'`` or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load: yaml.load without an explicit Loader is rejected by
        # recent PyYAML releases and can construct arbitrary objects.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model type: %r' % (args.model,))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Try the bare model first; if the keys do not match, retry on the
    # DataParallel wrapper (presumably the checkpoint was saved from a wrapped
    # model). Either way the model ends up wrapped exactly once.
    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    else:
        model = torch.nn.DataParallel(model)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(robotDataSet(args.data_dir,
                                              args.data_list,
                                              crop_size=(960, 1280),
                                              resize_size=(1280, 960),
                                              mean=IMG_MEAN,
                                              scale=False,
                                              mirror=False,
                                              set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second loader over the same list at 1.25x resolution.
    scale = 1.25
    testloader2 = data.DataLoader(robotDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(960 * scale), round(1280 * scale)),
        resize_size=(round(1280 * scale), round(960 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(960, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(960, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, _, name = batch
        image2, _, _, _ = batch2
        print(image.shape)

        inputs = image.cuda()
        inputs2 = image2.cuda()
        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, NUM_STEPS),
              end='')

        # Inference only: keep autograd off for every model type.
        with torch.no_grad():
            if args.model == 'DeepLab':
                # Base scale, as-is and through fliplr.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                # Enlarged scale, as-is and through fliplr.
                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2
                output_batch = output_batch.cpu().data.numpy()
            else:
                # 'DeeplabVGG' / 'Oracle': single pass at the base scale.
                output_batch = interp(model(inputs)).cpu().data.numpy()

        # Move the class axis last, then pick the best class per pixel.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output = output_batch[i, :, :]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name_tmp = name[i].split('/')[-1]
            dir_name = name[i].split('/')[-2]
            save_path = args.save + '/' + dir_name
            if not os.path.isdir(save_path):
                os.makedirs(save_path)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

    return args.save
def main():
    """Evaluate a model on 1000 support/query episodes and print IoU scores.

    Options come from ``get_arguments()``. Each episode is dequeued from
    ``SSDatalayer``; the reference and query images are stacked into one
    batch and passed to the model together with the reference label. The
    query prediction is blended with the query image and written to
    ``save_bins/que_pred/``. Per-class TP/FP/FN counts, a 21x21 confusion
    histogram and a ``SegScorer`` are accumulated, and the per-class IoU,
    the mean over the evaluated group, the binary IoU and the scorer results
    are printed at the end.

    Raises:
        ValueError: If ``args.model`` is not ``'DeeplabMulti'``, ``'Oracle'``
            or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = DeepLab(backbone='resnet', output_stride=8)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model type: %r' % (args.model,))

    device = torch.device("cuda" if not args.cpu else "cpu")

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location=device)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location=device)

    # For running different versions of PyTorch: ignore checkpoint entries
    # the model does not have and keep the model's own values for entries the
    # checkpoint lacks. The merged dict is the one that must be loaded.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    model = model.to(device)
    model.eval()

    num_classes = 20
    tp_list = [0] * num_classes
    fp_list = [0] * num_classes
    fn_list = [0] * num_classes

    hist = np.zeros((21, 21))
    group = 1
    scorer = SegScorer(num_classes=21)
    datalayer = SSDatalayer(group)

    # cv2.imwrite does not create missing directories.
    pred_dir = 'save_bins/que_pred'
    if not os.path.isdir(pred_dir):
        os.makedirs(pred_dir)

    for count in tqdm(range(1000)):
        dat = datalayer.dequeue()
        # NOTE(review): images and labels are taken from opposite
        # 'first'/'second' keys; kept as in the original, confirm against
        # the data layer.
        ref_img = torch.Tensor(dat['second_img'][0]).to(device)
        query_img = torch.Tensor(dat['first_img'][0]).to(device)
        query_label = torch.Tensor(dat['second_label'][0][0, :, :]).to(device)
        ref_label = torch.Tensor(dat['first_label'][0]).to(device)
        deploy_info = dat['deploy_info']

        # Add the batch axis the model call expects.
        ref_img_batch = torch.unsqueeze(ref_img, dim=0)
        ref_label_batch = torch.unsqueeze(ref_label, dim=1)
        query_img_batch = torch.unsqueeze(query_img, dim=0)

        samples = torch.cat([ref_img_batch, query_img_batch], 0)
        with torch.no_grad():
            pred = model(samples, ref_label_batch)
            # Resize the prediction to the query label's two dimensions.
            rows, cols = query_label.size()
            pred = F.upsample(pred, size=(rows, cols), mode='bilinear')
            pred = F.softmax(pred, dim=1).squeeze()
            _, pred = torch.max(pred, dim=0)
        pred = pred.data.cpu().numpy().astype(np.int32)

        # Query image, then the prediction mask blended over it.
        org_img = get_org_img(query_img.squeeze().cpu().data.numpy())
        img = mask_to_img(pred, org_img)
        cv2.imwrite('save_bins/que_pred/query_set_1_%d.png' % (count), img)

        query_label = query_label.cpu().numpy().astype(np.int32)
        # Class indices from the data layer start at 1.
        class_ind = int(deploy_info['first_semantic_labels'][0][0]) - 1
        scorer.update(pred, query_label, class_ind + 1)
        tp, tn, fp, fn = measure(query_label, pred)
        tp_list[class_ind] += tp
        fp_list[class_ind] += fp
        fn_list[class_ind] += fn

        # Relabel foreground pixels with the class id (in place, as before)
        # so the confusion histogram is indexed per class.
        pred[pred > 0.5] = class_ind + 1
        query_label[query_label > 0.5] = class_ind + 1

        hist += Metrics.fast_hist(pred, query_label, 21)

    # max(..., 1) guards classes where prediction and label are both empty.
    iou_list = [
        tp_list[ic] / float(max(tp_list[ic] + fp_list[ic] + fn_list[ic], 1))
        for ic in range(num_classes)
    ]

    print("-------------GROUP %d-------------" % (group))
    print(iou_list)
    class_indexes = range(group * 5, (group + 1) * 5)
    print('Mean:', np.mean(np.take(iou_list, class_indexes)))

    # Collapse the histogram into class 0 vs. everything else.
    binary_hist = np.array((hist[0, 0], hist[0, 1:].sum(), hist[1:, 0].sum(),
                            hist[1:, 1:].sum())).reshape((2, 2))
    bin_iu = np.diag(binary_hist) / (binary_hist.sum(1) + binary_hist.sum(0) -
                                     np.diag(binary_hist))
    print('Bin_iu:', bin_iu)

    scores = scorer.score()
    for k in scores.keys():
        print(k, np.mean(scores[k]), scores[k])
Example #8
0
def main():
    """Plot a t-SNE embedding of the model's features for each image.

    Options come from ``get_arguments()``. For every image from the
    ``cityscapesDataSet`` loader, the third model output (the feature map) is
    flattened to one feature vector per spatial location, embedded with
    ``TSNE`` and handed to ``scatter_plot`` together with the ground-truth
    labels resized to 65x129. This is an interactive debugging script: it
    stops in ``pdb`` before and after each plot.

    Raises:
        ValueError: If ``args.model`` is not ``'DeeplabMulti'``, ``'Oracle'``
            or ``'DeeplabVGG'``.
    """
    args = get_arguments()

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model type: %r' % (args.model,))

    device = torch.device("cuda" if not args.cpu else "cpu")

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location=device)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location=device)

    # For running different versions of PyTorch: ignore checkpoint entries
    # the model does not have and keep the model's own values for entries the
    # checkpoint lacks. The merged dict is the one that must be loaded.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    model = model.to(device)
    model.eval()

    testloader = data.DataLoader(
        cityscapesDataSet(args.data_dir, args.data_list,
                          crop_size=(1024, 512), mean=IMG_MEAN,
                          scale=False, mirror=False, set=args.set),
        batch_size=1, shuffle=False, pin_memory=True)

    # NOTE(review): hard-coded ground-truth location and city list.
    gt_dir = '/project/AdaptSegNet/data/Cityscapes/data/gtFine/val/'
    val_cities = ('frankfurt', 'lindau', 'munster')

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processd' % index)
        image, _, name = batch
        image = image.to(device)

        # Only the feature map is needed for the plot.
        with torch.no_grad():
            _, _, features = model(image)
        features = features.cpu().data[0].numpy()

        #### tsne plot ####
        gt_file = name[0].split('/')[-1]
        gt_file = gt_file.replace('leftImg8bit', 'gtFine_labelIds')

        # Ground-truth files live in one sub-directory per city.
        for city in val_cities:
            if city in gt_file:
                gt_file = gt_dir + city + '/' + gt_file
                break

        labels = np.array(Image.open(gt_file))
        labels = label_mask(labels).astype(np.float64)
        # 65x129 presumably matches the feature map's spatial size; the
        # reshape below fails if it does not.
        labels = skimage.transform.resize(labels, [65, 129])
        labels = label_mask(labels).astype(np.uint8)

        # One row per spatial location, one column per feature channel.
        n_feat, height, width = np.squeeze(features).shape
        features = features.reshape(n_feat, height * width).transpose(1, 0)
        labels = labels.reshape(height * width,)

        tsne = TSNE(random_state=RS).fit_transform(features)
        # Deliberate breakpoints: inspect the embedding, then the plot,
        # before moving on to the next image.
        pdb.set_trace()
        scatter_plot(tsne, labels)
        #### tsne plot ####

        pdb.set_trace()
def _domain_target(d_out, domain_label, gpu):
    """Build the target tensor for a discriminator/adversarial loss.

    Returns a float32 tensor with the same size as ``d_out`` filled with
    ``domain_label``, moved to CUDA device ``gpu``.
    """
    return Variable(
        torch.FloatTensor(d_out.data.size()).fill_(domain_label)).cuda(gpu)


def _report_segmentation_stats(hist):
    """Print per-class and overall statistics for confusion matrix ``hist``.

    Returns:
        Tuple ``(acc_overall, mIoU)`` as computed by ``result_stats``.
    """
    acc_overall, acc_percls, iu, fwIU = result_stats(hist)
    mIoU = np.mean(iu)

    # Name/value pairs are flattened into one argument list for str.format.
    per_class = np.array([[classes[i], value]
                          for i, value in enumerate(iu)]).flatten()
    print(('per cls IoU :' + '\n{:>14s} : {}' * len(iu)).format(*per_class))
    print('mIoU : {:0.2f}'.format(mIoU))
    print('fwIoU : {:0.2f}'.format(fwIU))
    print('pixel acc : {:0.2f}'.format(acc_overall))

    per_class = np.array([[classes[i], value]
                          for i, value in enumerate(acc_percls)]).flatten()
    print(('per cls acc :' +
           '\n{:>14s} : {}' * len(acc_percls)).format(*per_class))
    return acc_overall, mIoU


def _plot_label_map(fig, position, label_map, title):
    """Draw a class-index tensor on subplot ``position`` via ``print_palette``."""
    ax = fig.add_subplot(position)
    label_map = label_map.cpu().numpy().astype(np.float32)
    ax.imshow(print_palette(Image.fromarray(label_map).convert('L')))
    ax.axis("off")
    ax.set_title(title)


def _plot_image(fig, position, image, title):
    """Draw a channel-first image tensor on subplot ``position``."""
    ax = fig.add_subplot(position)
    ax.imshow(image.cpu().numpy().transpose((1, 2, 0)))
    ax.axis("off")
    ax.set_title(title)


def main():
    """Create the model and start the training.

    Configuration is read from the module-level ``args``; ``num_cls``,
    ``classes`` and ``writer`` are also taken from module scope.

    Each iteration performs, in order:
      1. a supervised segmentation step on a labelled (source) batch,
      2. an adversarial step on an unlabelled (target) batch with both
         discriminators frozen,
      3. a consistency step that trains on augmented target images using
         thresholded pseudo labels from the un-augmented prediction,
      4. a discriminator step on the detached source and target predictions.

    Training statistics are logged every 10 iterations and the test split is
    evaluated every 500 iterations. Checkpoints are written to
    ``args.snapshot_dir`` whenever the most recent mIoU improves and once more
    when ``args.num_steps_stop`` is reached, which ends the loop.

    Raises:
        ValueError: if ``args.model`` or ``args.gan`` is not a supported value.
    """

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Confidence threshold for pseudo labels; kept on the GPU so it can be
    # compared directly against the per-pixel maximum softmax probability.
    tau = torch.ones(1) * args.tau
    tau = tau.cuda(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight";
            # the leading component is dropped to match this model's names.
            key_parts = key.split('.')
            # With 19 classes the checkpoint's layer5 weights are not copied.
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params, False)
    elif args.model == 'DeepLabVGG':
        model = DeeplabVGG(pretrained=True, num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Images are only converted to tensors here; normalisation is applied per
    # batch below so the un-normalised images stay available for plotting.
    weak_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    target_transform = transforms.Compose([])

    label_set = GTA5(
        root=args.data_dir,
        num_cls=19,
        split='all',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size,
    )
    unlabel_set = Cityscapes(
        root=args.data_dir_target,
        split=args.set,
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
    )

    test_set = Cityscapes(
        root=args.data_dir_target,
        split='val',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
    )

    label_loader = data.DataLoader(label_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=False)

    unlabel_loader = data.DataLoader(unlabel_set,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=False)

    test_loader = data.DataLoader(test_set,
                                  batch_size=2,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  pin_memory=False)

    # model.optim_parameters(args) handles per-model learning-rate settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    # All three networks must be registered with amp. One loss scaler is
    # requested per backward pass below (loss_id 0..6).
    [model, model_D1, model_D2], [optimizer, optimizer_D1, optimizer_D2] = \
        amp.initialize([model, model_D1, model_D2],
                       [optimizer, optimizer_D1, optimizer_D2],
                       opt_level="O1",
                       num_losses=7)

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError('Unsupported GAN loss: {}'.format(args.gan))

    # input sizes are (w, h); the interpolators take (h, w).
    interp = Interpolate(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = Interpolate(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_test = Interpolate(size=(input_size_target[1],
                                    input_size_target[0]),
                              mode='bilinear',
                              align_corners=True)

    normalize_transform = transforms.Compose([
        torchvision.transforms.Normalize(mean=mean, std=std),
    ])

    # labels for adversarial training
    source_label = 0
    target_label = 1

    max_mIoU = 0

    total_loss_seg_value1 = []
    total_loss_adv_target_value1 = []
    total_loss_D_value1 = []
    total_loss_con_value1 = []

    total_loss_seg_value2 = []
    total_loss_adv_target_value2 = []
    total_loss_D_value2 = []
    total_loss_con_value2 = []

    hist = np.zeros((num_cls, num_cls))

    for i_iter, (batch, batch_un) in enumerate(
            zip(roundrobin_infinite(label_loader),
                roundrobin_infinite(unlabel_loader))):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0
        loss_con_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_con_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # ---- train G ----

        # don't accumulate grads in D
        for param in model_D1.parameters():
            param.requires_grad = False

        for param in model_D2.parameters():
            param.requires_grad = False

        # train with source
        images, labels = batch
        images_orig = images
        images = transform_batch(images, normalize_transform)
        images = Variable(images).cuda(args.gpu)

        pred1, pred2 = model(images)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = loss_seg2 + args.lambda_seg * loss_seg1

        # proper normalization
        loss = loss / args.iter_size

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
        loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

        # train with target
        images_tar, labels_tar = batch_un
        images_tar_orig = images_tar
        images_tar = transform_batch(images_tar, normalize_transform)
        images_tar = Variable(images_tar).cuda(args.gpu)

        pred_target1, pred_target2 = model(images_tar)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        # The generator is trained to make target predictions be classified
        # as source by the discriminators.
        loss_adv_target1 = bce_loss(
            D_out1, _domain_target(D_out1, source_label, args.gpu))
        loss_adv_target2 = bce_loss(
            D_out2, _domain_target(D_out2, source_label, args.gpu))

        loss = (args.lambda_adv_target1 * loss_adv_target1 +
                args.lambda_adv_target2 * loss_adv_target2)
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
            scaled_loss.backward()

        loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
        ) / args.iter_size
        loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
        ) / args.iter_size

        # train with consistency loss (unsupervised phase)
        # The same policies and random draws are applied to the images and to
        # the pseudo labels below.
        policies = RandAugment().get_batch_policy(args.batch_size)
        rand_p1 = np.random.random(size=args.batch_size)
        rand_p2 = np.random.random(size=args.batch_size)
        random_dir = np.random.choice([-1, 1], size=[args.batch_size, 2])

        images_aug = aug_batch_tensor(images_tar_orig, policies, rand_p1,
                                      rand_p2, random_dir)

        images_aug_orig = images_aug
        images_aug = transform_batch(images_aug, normalize_transform)
        images_aug = Variable(images_aug).cuda(args.gpu)

        pred_target_aug1, pred_target_aug2 = model(images_aug)
        pred_target_aug1 = interp_target(pred_target_aug1)
        pred_target_aug2 = interp_target(pred_target_aug2)

        # Pseudo labels must not propagate gradients into the clean forward
        # pass; the detached tensors are also reused for the D step below.
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        max_pred1, psuedo_label1 = torch.max(F.softmax(pred_target1, dim=1), 1)
        max_pred2, psuedo_label2 = torch.max(F.softmax(pred_target2, dim=1), 1)

        psuedo_label1 = psuedo_label1.cpu().numpy().astype(np.float32)
        psuedo_label1_thre = psuedo_label1.copy()
        # threshold to don't care (255); plain `bool` replaces the removed
        # `np.bool` alias
        psuedo_label1_thre[(max_pred1 < tau).cpu().numpy().astype(bool)] = 255
        psuedo_label1_thre = aug_batch_numpy(psuedo_label1_thre, policies,
                                             rand_p1, rand_p2, random_dir)

        psuedo_label2 = psuedo_label2.cpu().numpy().astype(np.float32)
        psuedo_label2_thre = psuedo_label2.copy()
        psuedo_label2_thre[(max_pred2 < tau).cpu().numpy().astype(bool)] = 255
        psuedo_label2_thre = aug_batch_numpy(psuedo_label2_thre, policies,
                                             rand_p1, rand_p2, random_dir)

        psuedo_label1_thre = Variable(psuedo_label1_thre).cuda(args.gpu)
        psuedo_label2_thre = Variable(psuedo_label2_thre).cuda(args.gpu)

        if (psuedo_label1_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con1 = loss_calc(pred_target_aug1, psuedo_label1_thre,
                                  args.gpu)
            loss_con_value1 += loss_con1.data.cpu().numpy() / args.iter_size
        else:
            loss_con1 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        if (psuedo_label2_thre != 255).sum().cpu().numpy() > 0:
            # nll_loss doesn't support empty tensors
            loss_con2 = loss_calc(pred_target_aug2, psuedo_label2_thre,
                                  args.gpu)
            loss_con_value2 += loss_con2.data.cpu().numpy() / args.iter_size
        else:
            loss_con2 = torch.tensor(0.0, requires_grad=True).cuda(args.gpu)

        loss = args.lambda_con * loss_con1 + args.lambda_con * loss_con2
        # proper normalization
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=2) as scaled_loss:
            scaled_loss.backward()

        # ---- train D ----

        # bring back requires_grad
        for param in model_D1.parameters():
            param.requires_grad = True

        for param in model_D2.parameters():
            param.requires_grad = True

        # train with source
        pred1 = pred1.detach()
        pred2 = pred2.detach()

        D_out1 = model_D1(F.softmax(pred1, dim=1))
        D_out2 = model_D2(F.softmax(pred2, dim=1))

        loss_D1 = bce_loss(D_out1,
                           _domain_target(D_out1, source_label, args.gpu))
        loss_D2 = bce_loss(D_out2,
                           _domain_target(D_out2, source_label, args.gpu))

        # Halved because each discriminator sees a source and a target batch.
        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=3) as scaled_loss:
            scaled_loss.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=4) as scaled_loss:
            scaled_loss.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        # train with target (pred_target1/2 were already detached above)
        D_out1 = model_D1(F.softmax(pred_target1, dim=1))
        D_out2 = model_D2(F.softmax(pred_target2, dim=1))

        loss_D1 = bce_loss(D_out1,
                           _domain_target(D_out1, target_label, args.gpu))
        loss_D2 = bce_loss(D_out2,
                           _domain_target(D_out2, target_label, args.gpu))

        loss_D1 = loss_D1 / args.iter_size / 2
        loss_D2 = loss_D2 / args.iter_size / 2

        with amp.scale_loss(loss_D1, optimizer_D1, loss_id=5) as scaled_loss:
            scaled_loss.backward()
        with amp.scale_loss(loss_D2, optimizer_D2, loss_id=6) as scaled_loss:
            scaled_loss.backward()

        loss_D_value1 += loss_D1.data.cpu().numpy()
        loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, loss_con1 = {8:.3f}, loss_con2 = {9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_con_value1,
                    loss_con_value2))

        total_loss_seg_value1.append(loss_seg_value1)
        total_loss_adv_target_value1.append(loss_adv_target_value1)
        total_loss_D_value1.append(loss_D_value1)
        total_loss_con_value1.append(loss_con_value1)

        total_loss_seg_value2.append(loss_seg_value2)
        total_loss_adv_target_value2.append(loss_adv_target_value2)
        total_loss_D_value2.append(loss_D_value2)
        total_loss_con_value2.append(loss_con_value2)

        # Accumulate the training confusion matrix from the main head.
        hist += fast_hist(
            labels.cpu().numpy().flatten().astype(int),
            torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int),
            num_cls)

        if i_iter % 10 == 0:
            print('({}/{})'.format(i_iter + 1, int(args.num_steps)))
            avg_train_acc, mIoU = _report_segmentation_stats(hist)

            avg_train_loss_seg1 = np.mean(total_loss_seg_value1)
            avg_train_loss_adv1 = np.mean(total_loss_adv_target_value1)
            avg_train_loss_dis1 = np.mean(total_loss_D_value1)
            avg_train_loss_con1 = np.mean(total_loss_con_value1)
            avg_train_loss_seg2 = np.mean(total_loss_seg_value2)
            avg_train_loss_adv2 = np.mean(total_loss_adv_target_value2)
            avg_train_loss_dis2 = np.mean(total_loss_D_value2)
            avg_train_loss_con2 = np.mean(total_loss_con_value2)

            print('avg_train_acc      :', avg_train_acc)
            print('avg_train_loss_seg1 :', avg_train_loss_seg1)
            print('avg_train_loss_adv1 :', avg_train_loss_adv1)
            print('avg_train_loss_dis1 :', avg_train_loss_dis1)
            print('avg_train_loss_con1 :', avg_train_loss_con1)
            print('avg_train_loss_seg2 :', avg_train_loss_seg2)
            print('avg_train_loss_adv2 :', avg_train_loss_adv2)
            print('avg_train_loss_dis2 :', avg_train_loss_dis2)
            print('avg_train_loss_con2 :', avg_train_loss_con2)

            writer['train'].add_scalar('log/mIoU', mIoU, i_iter)
            writer['train'].add_scalar('log/acc', avg_train_acc, i_iter)
            writer['train'].add_scalar('log1/loss_seg', avg_train_loss_seg1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_adv', avg_train_loss_adv1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_dis', avg_train_loss_dis1,
                                       i_iter)
            writer['train'].add_scalar('log1/loss_con', avg_train_loss_con1,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_seg', avg_train_loss_seg2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_adv', avg_train_loss_adv2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_dis', avg_train_loss_dis2,
                                       i_iter)
            writer['train'].add_scalar('log2/loss_con', avg_train_loss_con2,
                                       i_iter)

            # Start a fresh accumulation window.
            hist = np.zeros((num_cls, num_cls))
            total_loss_seg_value1 = []
            total_loss_adv_target_value1 = []
            total_loss_D_value1 = []
            total_loss_con_value1 = []
            total_loss_seg_value2 = []
            total_loss_adv_target_value2 = []
            total_loss_D_value2 = []
            total_loss_con_value2 = []

            # 3x3 grid: columns are source / target / augmented target,
            # rows are labels / predictions / input images.
            fig = plt.figure(figsize=(15, 15))

            _plot_label_map(fig, 331, labels[0], 'labels')
            _plot_image(fig, 337, images_orig[0], 'datas')
            _, pred2 = torch.max(pred2, dim=1)
            _plot_label_map(fig, 334, pred2[0], 'predicts')

            _plot_label_map(fig, 332, labels_tar[0], 'tar_labels')
            _plot_image(fig, 338, images_tar_orig[0], 'tar_datas')
            _, pred_target2 = torch.max(pred_target2, dim=1)
            _plot_label_map(fig, 335, pred_target2[0], 'tar_predicts')

            print(policies[0], 'p1', rand_p1[0], 'p2', rand_p2[0],
                  'random_dir', random_dir[0])

            _plot_label_map(fig, 333, psuedo_label2_thre[0], 'psuedo_labels')
            _plot_image(fig, 339, images_aug_orig[0], 'aug_datas')
            _, pred_target_aug2 = torch.max(pred_target_aug2, dim=1)
            _plot_label_map(fig, 336, pred_target_aug2[0], 'aug_predicts')

            writer['train'].add_figure('image/',
                                       fig,
                                       global_step=i_iter,
                                       close=True)

        if i_iter % 500 == 0:
            loss1 = []
            loss2 = []
            # No gradients are needed for evaluation.
            # NOTE(review): the model is left in train() mode here, as in the
            # original code - confirm whether eval() is wanted.
            with torch.no_grad():
                for test_batch in test_loader:
                    images, labels = test_batch
                    images_orig = images
                    images = transform_batch(images, normalize_transform)
                    images = Variable(images).cuda(args.gpu)

                    pred1, pred2 = model(images)
                    pred1 = interp_test(pred1).detach()
                    pred2 = interp_test(pred2).detach()

                    loss_seg1 = loss_calc(pred1, labels, args.gpu)
                    loss_seg2 = loss_calc(pred2, labels, args.gpu)
                    loss1.append(loss_seg1.item())
                    loss2.append(loss_seg2.item())

                    hist += fast_hist(
                        labels.cpu().numpy().flatten().astype(int),
                        torch.argmax(
                            pred2,
                            dim=1).cpu().numpy().flatten().astype(int),
                        num_cls)

            print('test')
            # Visualise the last sample of the last test batch.
            fig = plt.figure(figsize=(15, 15))
            _plot_label_map(fig, 311, labels[-1], 'labels')
            _plot_image(fig, 313, images_orig[-1], 'datas')
            _, pred2 = torch.max(pred2, dim=1)
            _plot_label_map(fig, 312, pred2[-1], 'predicts')

            writer['test'].add_figure('test_image/',
                                      fig,
                                      global_step=i_iter,
                                      close=True)

            avg_test_acc, mIoU = _report_segmentation_stats(hist)

            avg_test_loss1 = np.mean(loss1)
            avg_test_loss2 = np.mean(loss2)
            print('avg_test_loss1 :', avg_test_loss1)
            print('avg_test_loss2 :', avg_test_loss2)
            print('avg_test_acc   :', avg_test_acc)
            writer['test'].add_scalar('log1/loss_seg', avg_test_loss1, i_iter)
            writer['test'].add_scalar('log2/loss_seg', avg_test_loss2, i_iter)
            writer['test'].add_scalar('log/acc', avg_test_acc, i_iter)
            writer['test'].add_scalar('log/mIoU', mIoU, i_iter)

            hist = np.zeros((num_cls, num_cls))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        # NOTE: mIoU is whichever value was computed most recently - the
        # training-window mIoU, or the test mIoU on test iterations.
        if max_mIoU < mIoU:
            max_mIoU = mIoU
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + 'best_iter' + '_D2.pth'))
Example #10
0
def main():
    """Create the model and start the evaluation process.

    Restores a checkpoint into the model selected by ``args.model``, runs it
    over ``SmokeDataset`` one image at a time and, for every image whose
    prediction contains more than one class, saves the predicted mask (raw
    and colourised) under ``args.save``, computes two IoU values and logs
    images and scalars to a TensorBoard run directory inside ``args.save``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """

    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        saved_state_dict = torch.load(args.restore_from)

    # For running different versions of pytorch: drop checkpoint entries the
    # model does not know and keep the model's own values for entries the
    # checkpoint lacks. The merged dict is the one that must be loaded -
    # loading the filtered checkpoint alone fails on any missing key.
    model_dict = model.state_dict()
    saved_state_dict = {
        k: v
        for k, v in saved_state_dict.items() if k in model_dict
    }
    model_dict.update(saved_state_dict)
    model.load_state_dict(model_dict)

    model.eval()
    model.cuda(gpu0)

    # One TensorBoard run directory per invocation, named by start time.
    exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
    log_dir = os.path.join(args.save, exp_name)
    writer = SummaryWriter(log_dir)

    testloader = data.DataLoader(SmokeDataset(image_size=(640, 360),
                                              dataset_mean=IMG_MEAN),
                                 batch_size=1,
                                 shuffle=True,
                                 pin_memory=True)

    # NOTE(review): nn.Upsample takes size as (H, W) while the dataset is
    # built with image_size=(640, 360) - confirm the intended orientation.
    interp = nn.Upsample(size=(640, 360),
                         mode='bilinear',
                         align_corners=True)

    iou_sum_fg = 0
    iou_count_fg = 0

    iou_sum_bg = 0
    iou_count_bg = 0

    for index, batch in enumerate(testloader):
        if (index + 1) % 100 == 0:
            print('%d processd' % index)

        image, label, name = batch

        with torch.no_grad():
            prediction = model(Variable(image).cuda(gpu0))
        if args.model == 'DeeplabMulti':
            # The multi-level model returns two outputs; only the second one
            # is evaluated.
            _, prediction = prediction

        output = interp(prediction).cpu()
        # Keep the upsampled scores for the IoU computation below; this is
        # needed for every model type, not only DeeplabMulti.
        orig_output = output.detach().clone()
        output = output.data[0].numpy()

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        classes_seen = set(output.ravel().tolist())
        output_col = colorize_mask(output)
        output = Image.fromarray(output)
        name = name[0]

        # Images predicted as a single class are skipped entirely.
        if len(classes_seen) > 1:
            print(classes_seen)
            print(Counter(np.asarray(output).ravel()))

            # Add the dataset mean back, rescale to [0, 1] and reverse the
            # channel order for logging.
            image = image.squeeze()
            for c in range(3):
                image[c, :, :] += IMG_MEAN[c]
            image = (image - image.min()) / (image.max() - image.min())
            image = image[[2, 1, 0], :, :]
            print(image.shape, image.min(), image.max())

            output.save(os.path.join(args.save, name + '.png'))
            output_col.save(os.path.join(args.save, name + '_color.png'))

            output_argmaxs = torch.argmax(orig_output.squeeze(), dim=0)
            mask1 = (output_argmaxs == 0).float() * 255
            label = label.squeeze()

            iou_fg = iou_pytorch(mask1, label)
            print("foreground IoU", iou_fg)
            iou_sum_fg += iou_fg
            iou_count_fg += 1

            # Complementary mask and inverted label for the second IoU.
            mask2 = (output_argmaxs > 0).float() * 255
            label2 = label.max() - label

            iou_bg = iou_pytorch(mask2, label2)
            print("IoU for background: ", iou_bg)
            iou_sum_bg += iou_bg
            iou_count_bg += 1

            writer.add_images('input_images',
                              tf.resize(image[[2, 1, 0]], [1080, 1920]),
                              index,
                              dataformats='CHW')

            print("shape of label", label.shape)
            label_reshaped = tf.resize(label.unsqueeze(0),
                                       [1080, 1920]).squeeze()
            print("label reshaped: ", label_reshaped.shape)
            writer.add_images('labels',
                              label_reshaped,
                              index,
                              dataformats='HW')
            writer.add_images(
                'output/1',
                255 - np.asarray(tf.resize(output, [1080, 1920])) * 255,
                index,
                dataformats='HW')
            writer.add_scalar('iou/smoke', iou_fg, index)
            writer.add_scalar('iou/background', iou_bg, index)
            writer.add_scalar('iou/mean', (iou_bg + iou_fg) / 2, index)
            writer.flush()

    if iou_count_fg > 0:
        print("Mean IoU, foreground: {}".format(iou_sum_fg / iou_count_fg))
        print("Mean IoU, background: {}".format(iou_sum_bg / iou_count_bg))
        print("Mean IoU, averaged over classes: {}".format(
            (iou_sum_fg + iou_sum_bg) / (iou_count_fg + iou_count_bg)))

    # Release the event file handle once evaluation is done.
    writer.close()
def main():
    """Evaluate a segmentation network on the NYU validation split.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` (an ``http`` URL or a local path), runs it over the
    ``val`` split one image at a time and prints two streaming metric
    summaries: one on the raw labels and one after ``nyu_nyu_map`` has
    replaced a fixed set of labels with the value 255.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """

    args = get_arguments()

    os.makedirs(args.save, exist_ok=True)

    # A label ``x`` is replaced by 255 when ``x + 1`` is one of these keys;
    # every other label is passed through unchanged.
    nyu_nyu_dict = {11: 255, 13: 255, 15: 255, 17: 255, 19: 255, 20: 255,
                    21: 255, 23: 255, 24: 255, 25: 255, 26: 255, 27: 255,
                    28: 255, 29: 255, 31: 255, 32: 255, 33: 255}
    nyu_nyu_map = np.vectorize(lambda x: nyu_nyu_dict.get(x + 1, x))
    # Also stored on ``args``, which is handed to the dataset as ``opt`` below.
    args.nyu_nyu_map = nyu_nyu_map

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously this fell through and failed later on an unbound `model`.
        raise ValueError('unsupported model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Compatibility shim for checkpoints written by a different PyTorch
    # version: start from the model's own state, overwrite every entry the
    # checkpoint provides and ignore checkpoint keys the model does not have.
    # The merged dict (not the filtered checkpoint) must be loaded, otherwise
    # the strict load still fails on keys missing from the checkpoint.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    metrics = StreamSegMetrics(args.num_classes)
    metrics_remap = StreamSegMetrics(args.num_classes)
    ignore_label = 255
    val_transform = transforms.Compose([
        transforms.Crop([args.height + 1, args.width + 1],
                        crop_type='center',
                        padding=IMG_MEAN,
                        ignore_label=ignore_label),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMG_MEAN, std=[1, 1, 1]),
    ])
    val_dst = NYU(root=args.data_dir,
                  opt=args,
                  split='val',
                  transform=val_transform,
                  imWidth=args.width,
                  imHeight=args.height,
                  phase="TEST",
                  randomize=False)
    print("Dset Length {}".format(len(val_dst)))
    testloader = data.DataLoader(val_dst,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # Upsample the network output to the size of the cropped input.
    interp = nn.Upsample(size=(args.height + 1, args.width + 1),
                         mode='bilinear',
                         align_corners=True)
    metrics.reset()

    # Pure inference: do not build autograd graphs.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, targets, name = batch
            image = image.to(device)
            print(index)

            if args.model == 'DeeplabMulti':
                # Only the second output is used for the prediction.
                _, logits = model(image)
            else:  # 'DeeplabVGG' or 'Oracle'
                logits = model(image)
            output = interp(logits).cpu()[0].numpy()

            targets = targets.cpu().numpy()
            # Class index with the highest score per pixel (channel axis 0).
            output = np.asarray(np.argmax(output, axis=0), dtype=np.uint8)
            preds = output[None, :, :]

            metrics.update(targets, preds)
            metrics_remap.update(args.nyu_nyu_map(targets),
                                 args.nyu_nyu_map(preds))

    print(metrics.get_results())
    print(metrics_remap.get_results())
Example #12
0
def main():
    """Run every GTA5->Cityscapes snapshot over the test set and save results.

    For each ``i`` in ``range(26, 126)`` the snapshot
    ``./snapshots/GTA2Cityscapes/GTA5_<i * 2000>.pth`` is loaded into a fresh
    network, applied to every image of the Cityscapes loader, and the
    per-pixel argmax is written to ``./result/GTA2Cityscapes_<i * 2000>/`` both
    as a raw label image and as a colorized ``*_color.png``.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """

    # Arguments, data loader and upsampler do not depend on the snapshot, so
    # they are created once instead of once per snapshot.
    args = get_arguments()
    gpu0 = args.gpu

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    for i in range(26, 126):  # snapshots at steps 52000 .. 250000
        step = i * 2000
        model_path = './snapshots/GTA2Cityscapes/GTA5_{0:d}.pth'.format(step)
        save_path = './result/GTA2Cityscapes_{0:d}'.format(step)

        os.makedirs(save_path, exist_ok=True)

        if args.model == 'ResNet':
            model = DeeplabMulti(num_classes=args.num_classes)
        elif args.model == 'VGG':
            model = DeeplabVGG(num_classes=args.num_classes)
        else:
            # Previously this fell through and failed on an unbound `model`.
            raise ValueError('unsupported model: {}'.format(args.model))

        saved_state_dict = torch.load(model_path)
        model.load_state_dict(saved_state_dict)

        model.eval()
        model.cuda(gpu0)

        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                # The model is expected to return two outputs; only the
                # second one is used for the prediction.
                _, output2 = model(image.cuda(gpu0))

                output = interp(output2).cpu().data[0].numpy()

                # CHW scores -> HW map of the best-scoring class per pixel.
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save('%s/%s' % (save_path, name))

                output_col.save('%s/%s_color.png' %
                                (save_path, name.split('.')[0]))

        print(save_path)
Example #13
0
def main():
    """Create the model and start the training.

    Configuration is read from the module-level ``args``; the training mode is
    selected by the module-level ``STAGE``:

    * ``STAGE == 1``: supervised segmentation loss on source batches
      (alternating between the translated and the stylized source loaders)
      plus output-level adversarial training against ``FCDiscriminator``.
    * ``STAGE == 2``: supervised segmentation loss on source batches plus a
      segmentation loss on target batches that carry labels.

    Training resumes from the checkpoint named by the module-level
    ``CHECKPOINT`` via ``load_checkpoint``. Every ``args.save_pred_every``
    iterations a checkpoint and both state dicts are saved, the snapshot is
    evaluated, and one line of results is appended to ``args.output_file``.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """
    w, h = map(int, args.input_size_source.split(','))
    input_size_source = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight";
            # the first component is dropped to match this model's keys.
            i_parts = i.split('.')
            # With 19 classes the 'layer5' weights are not copied from the
            # checkpoint and keep the model's own initial values.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)

    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           pretrained=True,
                           vgg16_caffe_path=args.restore_from)

    else:
        # Previously this fell through and failed later on an unbound `model`.
        raise ValueError('unsupported model: {}'.format(args.model))

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    model.train()
    model.cuda(args.gpu)
    cudnn.benchmark = True

    # Discriminator setting
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda(args.gpu)

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_adv_label = 0
    target_adv_label = 1

    # Dataloader
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.translated_data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size_source,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    style_trainloader = data.DataLoader(GTA5DataSet(
        args.stylized_data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_source,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)

    style_trainloader_iter = enumerate(style_trainloader)

    if STAGE == 1:
        targetloader = data.DataLoader(cityscapesDataSet(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

    else:
        # Dataloader for self-training
        # NOTE(review): `label_folder` is still the placeholder string below;
        # it presumably has to be replaced by a real directory - confirm.
        targetloader = data.DataLoader(cityscapesDataSetLabel(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set,
            label_folder='Path to generated pseudo labels'),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # Input sizes are (width, height); nn.Upsample takes (height, width).
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # load checkpoint
    # Built with osp.join so the path matches where checkpoints are written
    # below, whether or not snapshot_dir ends with a path separator.
    model, model_D, optimizer, start_iter = load_checkpoint(
        model,
        model_D,
        optimizer,
        filename=osp.join(args.snapshot_dir,
                          'checkpoint_' + CHECKPOINT + '.pth.tar'))

    for i_iter in range(start_iter, args.num_steps):
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train segmentation network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        if STAGE == 1:
            # Alternate between the translated and the stylized source data.
            if i_iter % 2 == 0:
                _, batch = next(trainloader_iter)
            else:
                _, batch = next(style_trainloader_iter)

        else:
            _, batch = next(trainloader_iter)

        image_source, label, _, _ = batch
        image_source = Variable(image_source).cuda(args.gpu)

        pred_source = model(image_source)
        pred_source = interp(pred_source)

        loss_seg_source = loss_calc(pred_source, label, args.gpu)
        loss_seg_source_value = loss_seg_source.item()
        loss_seg_source.backward()

        if STAGE == 2:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, target_label, _, _ = batch
            image_target = Variable(image_target).cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            # target segmentation loss
            loss_seg_target = loss_calc(pred_target,
                                        target_label,
                                        gpu=args.gpu)
            loss_seg_target.backward()

        # optimize
        optimizer.step()

        if STAGE == 1:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, _, _ = batch
            image_target = Variable(image_target).cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            # output-level adversarial training
            # dim=1 is what the implicit-dim softmax picked for these 4-D
            # tensors; stating it avoids the deprecation warning.
            # NOTE(review): optimizer.step() has already run above and
            # optimizer.zero_grad() runs at the top of the next iteration, so
            # the gradients this backward pass leaves on `model` do not appear
            # to be applied - confirm whether the step was meant to come
            # after loss_adv.backward().
            D_output_target = model_D(F.softmax(pred_target, dim=1))
            loss_adv = bce_loss(
                D_output_target,
                Variable(
                    torch.FloatTensor(D_output_target.data.size()).fill_(
                        source_adv_label)).cuda(args.gpu))
            loss_adv = loss_adv * args.lambda_adv
            loss_adv.backward()

            # train discriminator
            for param in model_D.parameters():
                param.requires_grad = True

            # Detach so the discriminator loss does not reach `model`.
            pred_source = pred_source.detach()
            pred_target = pred_target.detach()

            D_output_source = model_D(F.softmax(pred_source, dim=1))
            D_output_target = model_D(F.softmax(pred_target, dim=1))

            loss_D_source = bce_loss(
                D_output_source,
                Variable(
                    torch.FloatTensor(D_output_source.data.size()).fill_(
                        source_adv_label)).cuda(args.gpu))
            loss_D_target = bce_loss(
                D_output_target,
                Variable(
                    torch.FloatTensor(D_output_target.data.size()).fill_(
                        target_adv_label)).cuda(args.gpu))

            loss_D_source = loss_D_source / 2
            loss_D_target = loss_D_target / 2

            loss_D_source.backward()
            loss_D_target.backward()

            # optimize
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg_source = {2:.5f}'.format(
            i_iter, args.num_steps, loss_seg_source_value))

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            state = {
                'iter': i_iter,
                'model': model.state_dict(),
                'model_D': model_D.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(
                state,
                osp.join(args.snapshot_dir,
                         'checkpoint_' + str(i_iter) + '.pth.tar'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_D_' + str(i_iter) + '.pth'))

            cityscapes_eval_dir = osp.join(args.cityscapes_eval_dir,
                                           str(i_iter))
            if not os.path.exists(cityscapes_eval_dir):
                os.makedirs(cityscapes_eval_dir)

            # NOTE(review): `eval` here is presumably a module-level function
            # that shadows the builtin - verify against the rest of the file.
            eval(osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'),
                 cityscapes_eval_dir, i_iter)

            iou19, iou13, iou = compute_mIoU(cityscapes_eval_dir, i_iter)
            # Context manager closes the file even if the write raises.
            with open(args.output_file, 'a') as outputfile:
                outputfile.write(
                    str(i_iter) + '\t' + str(iou19) + '\t' +
                    str(iou.replace('\n', ' ')) + '\n')
def main():
    """Create the model and start the evaluation process.

    Evaluates every snapshot ``GTA5_<step>.pth`` with
    ``step = save_pred_every, 2 * save_pred_every, ...`` up to
    ``num_steps_stop``. Snapshots are read from ``./snapshots/<run>/`` and the
    predictions (raw label image plus ``*_color.png``) are written to
    ``<args.save>/<run>/step<step>/``, where ``<run>`` is ``source_only`` when
    the module-level ``SOURCE_ONLY`` is set and otherwise ``single_level`` or
    ``multi_level`` according to ``args.level``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
        NotImplementedError: if ``args.level`` is not a supported choice
            (only checked when ``SOURCE_ONLY`` is false).
    """
    seed = 1338
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    args = get_arguments()

    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously this fell through and failed later on an unbound `model`.
        raise ValueError('unsupported model: {}'.format(args.model))

    # The same name selects both the snapshot folder and the output folder.
    if SOURCE_ONLY:
        run_name = 'source_only'
    elif args.level == 'single-level':
        run_name = 'single_level'
    elif args.level == 'multi-level':
        run_name = 'multi_level'
    else:
        raise NotImplementedError(
            'level choice {} is not implemented'.format(args.level))

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)
    # Wrap once and keep the bare module for weight loading. Re-wrapping on
    # every snapshot nested DataParallel and prefixed all state-dict keys with
    # 'module.', so from the second snapshot on no checkpoint key matched.
    eval_model = nn.DataParallel(model) if args.multi_gpu else model
    eval_model.eval()

    # Loader and upsampler are identical for every snapshot.
    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    for files in range(int(args.num_steps_stop / args.save_pred_every)):
        step = (files + 1) * args.save_pred_every
        print('Step: ', step)

        saved_state_dict = torch.load('./snapshots/' + run_name + '/GTA5_' +
                                      str(step) + '.pth')

        # Compatibility shim for checkpoints written by a different PyTorch
        # version: overwrite the model's own state with every entry the
        # checkpoint provides, then load the merged dict. Loading only the
        # filtered checkpoint would still fail on keys it does not contain.
        model_dict = model.state_dict()
        model_dict.update({
            k: v
            for k, v in saved_state_dict.items() if k in model_dict
        })
        model.load_state_dict(model_dict)

        step_dir = os.path.join(args.save, run_name, 'step' + str(step))
        os.makedirs(step_dir, exist_ok=True)

        # Pure inference: do not build autograd graphs.
        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                image = image.to(device)

                if args.model == 'DeeplabMulti':
                    # Only the second output is used for the prediction.
                    _, logits = eval_model(image)
                else:  # 'DeeplabVGG' or 'Oracle'
                    logits = eval_model(image)
                output = interp(logits).cpu().data[0].numpy()

                # CHW scores -> HW map of the best-scoring class per pixel.
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save(os.path.join(step_dir, name))
                output_col.save(
                    os.path.join(step_dir,
                                 name.split('.')[0] + '_color.png'))
Example #15
0
def main():
    """Create the model and start the evaluation process.

    Reads ``opts.yaml`` from the directory of ``args.restore_from`` to pick
    the model type, restores the weights, and runs the model over the Dark
    Zurich loader. For the ``'DeepLab'`` model the prediction is the average
    of four softmax outputs: the input at the requested size and at 1.25x
    that size, each with and without a left-right flip. For every image the
    argmax label map and a colorized version are saved under
    ``<args.save>/<parent dir of the image>/``; for the ``test`` and ``val``
    sets the label map, a copy with low-score pixels set to 255, and a 16-bit
    confidence map are additionally written under ``<args.save>/submit/``.

    Returns:
        The output directory ``args.save``.

    Raises:
        ValueError: if the configured model type is not supported.
    """

    args = get_arguments()

    w, h = map(int, args.input_size.split(','))

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load: yaml.load without an explicit Loader can construct
        # arbitrary objects and is rejected outright by PyYAML >= 6.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    model_name = os.path.basename(os.path.dirname(args.restore_from))
    #args.save += model_name

    os.makedirs(args.save, exist_ok=True)
    confidence_path = os.path.join(args.save, 'submit/confidence')
    label_path = os.path.join(args.save, 'submit/labelTrainIds')
    label_invalid_path = os.path.join(args.save,
                                      'submit/labelTrainIds_invalid')
    for path in [confidence_path, label_path, label_invalid_path]:
        os.makedirs(path, exist_ok=True)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Previously this fell through and failed later on an unbound `model`.
        raise ValueError('unsupported model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # Key mismatch: retry with the model wrapped in DataParallel, whose
        # state-dict keys carry a 'module.' prefix.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(DarkZurichDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(h, w),
                                                   resize_size=(w, h),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second loader over the same list at 1.25x resolution; it is iterated in
    # lockstep with the first one for multi-scale testing.
    scale = 1.25
    testloader2 = data.DataLoader(DarkZurichDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(h * scale), round(w * scale)),
        resize_size=(round(w * scale), round(h * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1080, 1920),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1080, 1920), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    log_sm = torch.nn.LogSoftmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')
    # `prior` and `ratio` are only consumed by the disabled re-weighting line
    # inside the loop; the file is still loaded as before.
    prior = np.load('./utils/prior_all.npy').transpose(
        (2, 0, 1))[np.newaxis, :, :, :]
    prior = torch.from_numpy(prior)
    for index, img_data in enumerate(zip(testloader, testloader2)):
        batch, batch2 = img_data
        image, _, name = batch
        image2, _, _ = batch2

        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, args.batchsize * len(testloader)),
              end='')

        # Pure inference for every model type: no autograd graphs.
        with torch.no_grad():
            inputs = image.cuda()
            if args.model == 'DeepLab':
                inputs2 = image2.cuda()

                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))

                # Per-pixel KL term between the two heads, summed over the
                # class dimension; only used by the disabled heatmap code.
                heatmap_batch = torch.sum(kl_distance(log_sm(output1),
                                                      sm(output2)),
                                          dim=1)

                # Flip the input, then flip the outputs back before fusing.
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2
                ratio = 0.95
                # Average of the four accumulated softmax maps.
                output_batch = output_batch.cpu() / 4
                # output_batch = output_batch *(ratio + (1 - ratio) * prior)
                output_batch = output_batch.data.numpy()
                heatmap_batch = heatmap_batch.cpu().data.numpy()
            else:  # 'DeeplabVGG' or 'Oracle'
                output_batch = model(inputs)
                output_batch = interp(output_batch).cpu().data.numpy()

        # NCHW -> NHWC, then best class and its score per pixel.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        score_batch = np.max(output_batch, axis=3)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        # Pixels whose best score is below this value are written as 255 in
        # the "invalid" label map.
        threshold = 0.3274
        for i in range(output_batch.shape[0]):
            output_single = output_batch[i, :, :]
            output_col = colorize_mask(output_single)
            output = Image.fromarray(output_single)

            name_tmp = name[i].split('/')[-1]
            dir_name = name[i].split('/')[-2]
            save_path = args.save + '/' + dir_name
            os.makedirs(save_path, exist_ok=True)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

            # heatmap_tmp = heatmap_batch[i,:,:]/np.max(heatmap_batch[i,:,:])
            # fig = plt.figure()
            # plt.axis('off')
            # heatmap = plt.imshow(heatmap_tmp, cmap='viridis')
            # fig.colorbar(heatmap)
            # fig.savefig('%s/%s_heatmap.png' % (save_path, name_tmp.split('.')[0]))

            if args.set == 'test' or args.set == 'val':
                # label
                output.save('%s/%s' % (label_path, name_tmp))
                # label invalid: done after the saves above because it
                # modifies the label map in place.
                output_single[score_batch[i, :, :] < threshold] = 255
                output = Image.fromarray(output_single)
                output.save('%s/%s' % (label_invalid_path, name_tmp))
                # confidence, scaled to the 16-bit range
                confidence = score_batch[i, :, :] * 65535
                confidence = np.asarray(confidence, dtype=np.uint16)
                print(confidence.min(), confidence.max())
                iio.imwrite('%s/%s' % (confidence_path, name_tmp), confidence)

    return args.save
def main():
    """Create the model and start the evaluation process.

    Builds the network selected by ``args.model``, restores its weights from
    ``args.restore_from`` (URL or local path), runs it over the BDD test
    loader and, for every image, writes a label-index PNG and a colorized PNG
    into ``args.save``.

    When ``args.save_confidence`` is set, the value returned by
    ``get_confidence`` for each image is also written to ``<stem>_c.txt`` and,
    after the loop, the names of the highest-scoring third of the images are
    written to ``list.txt`` (one per line).

    Raises:
        ValueError: if ``args.model`` is not one of the supported model names.
    """

    args = get_arguments()

    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
    else:
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError('Unsupported model type: %s' % args.model)

    # Deserialize onto the CPU first so that a checkpoint saved on a GPU can
    # still be restored when running with --cpu; the model is moved to the
    # requested device right after.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from,
                                              map_location='cpu')
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')
    model.load_state_dict(saved_state_dict, strict=False)

    print(model)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)

    model.eval()

    testloader = data.DataLoader(BDDDataSet(args.data_dir,
                                            args.data_list,
                                            crop_size=(960, 540),
                                            mean=IMG_MEAN,
                                            scale=False,
                                            mirror=False,
                                            set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)  # 960 540

    interp = nn.Upsample(size=(720, 1280), mode='bilinear', align_corners=True)

    # Entries are [confidence, name-as-returned-by-the-loader].
    c_list = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 10 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.to(device)

            output = model(image)

            file_name = name[0].split('/')[-1]
            stem = file_name.split('.')[0]

            if args.save_confidence:
                confidence = get_confidence(output).cpu().item()
                c_list.append([confidence, name])
                with open('%s/%s_c.txt' % (args.save, stem), 'w') as record:
                    record.write('%.5f' % confidence)

            output = interp(output).cpu().data[0].numpy()

            # (C, H, W) -> (H, W, C), then pick the best class per pixel.
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            output.save('%s/%s' % (args.save, file_name[:-4] + '.png'))
            output_col.save('%s/%s_color.png' % (args.save, stem))

    if args.save_confidence:
        # Highest confidence first; keep the top third.
        c_list.sort(key=lambda entry: entry[0], reverse=True)
        with open('list.txt', 'w') as select:
            for confidence, names in c_list[:len(c_list) // 3]:
                print(confidence)
                print(names)
                select.write(names[0])
                select.write('\n')

    print(args.save)
Example #17
0
def main():
    """Create the model and start the evaluation process.

    Restores the network selected by ``args.model`` from
    ``args.restore_from``, runs it over the GTA5 loader in batches of
    ``args.batchsize`` and writes, for every image, a label-index image and a
    colorized PNG into ``args.save`` (which is suffixed with the name of the
    checkpoint's parent directory).

    Returns:
        The output directory path (``args.save`` after the suffix is added).

    Raises:
        ValueError: if ``args.model`` is not one of the supported model names.
    """

    args = get_arguments()

    batchsize = args.batchsize

    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes,
                             train_bn=False,
                             norm_style='in')
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry with the model wrapped in DataParallel, which
        # changes the parameter names the state dict is matched against.
        # Only loading errors are caught so that unrelated failures
        # (e.g. KeyboardInterrupt) are no longer swallowed.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda()

    testloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                             args.data_list,
                                             crop_size=(640, 1280),
                                             resize_size=(1280, 640),
                                             mean=IMG_MEAN,
                                             scale=False,
                                             mirror=False),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(640, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(640, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, batch in enumerate(testloader):
        if (index * batchsize) % 100 == 0:
            print('%d processd' % (index * batchsize))
        image, _, _, name = batch
        print(image.shape)

        inputs = Variable(image).cuda()
        if args.model == 'DeeplabMulti':
            output1, output2 = model(inputs)
            # Blend both heads before the softmax, then upsample.
            output_batch = interp(sm(0.5 * output1 +
                                     output2)).cpu().data.numpy()
        else:
            # 'DeeplabVGG' or 'Oracle'; reuse the tensor already on the GPU
            # instead of transferring the batch a second time.
            output_batch = model(inputs)
            output_batch = interp(output_batch).cpu().data.numpy()

        # (N, C, H, W) -> (N, H, W, C), then pick the best class per pixel.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output = output_batch[i, :, :]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name_tmp = name[i].split('/')[-1]
            output.save('%s/%s' % (args.save, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (args.save, name_tmp.split('.')[0]))

    return args.save
Example #18
0
def main():
    """Create the model and start the training.

    Trains the segmentation network selected by ``args.model`` on the source
    data (``GTA5BDDDataSet``) with a cross-entropy loss while an adversarial
    loss from the discriminator ``model_D1`` is applied to predictions on the
    target data (``BDDDataSet``).  Gradients are accumulated over
    ``args.iter_size`` sub-iterations before each optimizer step.  Snapshots
    of the segmentation network and ``model_D1`` are written to
    ``args.snapshot_dir`` by rank 0 only.

    Side effects:
        Rebinds the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported model names.
    """
    global args
    args = get_arguments()
    if args.dist:
        init_dist(args.launcher, backend=args.backend)
    world_size = 1
    rank = 0
    if args.dist:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    def load_checkpoint():
        """Return the state dict at args.restore_from (URL or local path)."""
        if args.restore_from[:4] == 'http':
            return model_zoo.load_url(args.restore_from)
        # torch.load has no `strict` option (that belongs to load_state_dict).
        return torch.load(args.restore_from)

    # Create network
    if args.model == 'Deeplab':
        model = DeeplabMulti(num_classes=args.num_classes)
        saved_state_dict = load_checkpoint()

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            # Drop the first name component; skip 'layer5' entries when
            # training with 19 classes.
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        saved_state_dict = load_checkpoint()
        model.load_state_dict(saved_state_dict, strict=False)
    elif args.model == 'DeeplabVGGBN':
        deeplab_vggbn.BatchNorm = SyncBatchNorm2d
        model = deeplab_vggbn.DeeplabVGGBN(num_classes=args.num_classes)
        saved_state_dict = load_checkpoint()
        # Apply the weights for both URL and local checkpoints (previously a
        # checkpoint fetched from a URL was downloaded but never loaded).
        model.load_state_dict(saved_state_dict, strict=False)
        del saved_state_dict
    else:
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError('Unsupported model type: %s' % args.model)

    model.train()
    model.to(device)
    if args.dist:
        broadcast_params(model)

    if rank == 0:
        print(model)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D1.train()
    if args.dist:
        broadcast_params(model_D1)
    if args.restore_D is not None:
        D_dict = torch.load(args.restore_D)
        model_D1.load_state_dict(D_dict, strict=False)
        del D_dict

    model_D2.train()
    if args.dist:
        broadcast_params(model_D2)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_data = GTA5BDDDataSet(args.data_dir,
                                args.data_list,
                                max_iters=args.num_steps * args.iter_size *
                                args.batch_size,
                                crop_size=input_size,
                                scale=args.random_scale,
                                mirror=args.random_mirror,
                                mean=IMG_MEAN)
    train_sampler = None
    if args.dist:
        train_sampler = DistributedSampler(train_data)
    trainloader = data.DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=False if train_sampler else True,
                                  num_workers=args.num_workers,
                                  pin_memory=False,
                                  sampler=train_sampler)

    trainloader_iter = enumerate(cycle(trainloader))

    target_data = BDDDataSet(args.data_dir_target,
                             args.data_list_target,
                             max_iters=args.num_steps * args.iter_size *
                             args.batch_size,
                             crop_size=input_size_target,
                             scale=False,
                             mirror=args.random_mirror,
                             mean=IMG_MEAN,
                             set=args.set)
    target_sampler = None
    if args.dist:
        target_sampler = DistributedSampler(target_data)
    targetloader = data.DataLoader(target_data,
                                   batch_size=args.batch_size,
                                   shuffle=False if target_sampler else True,
                                   num_workers=args.num_workers,
                                   pin_memory=False,
                                   sampler=target_sampler)

    targetloader_iter = enumerate(cycle(targetloader))

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    # NOTE(review): optimizer_D2 is created and zeroed but never stepped, and
    # model_D2 never receives a loss in this function.
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    if args.tensorboard and rank == 0:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    torch.cuda.empty_cache()
    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        # The *2 values are never updated below; they are only reported.
        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)

            images, labels, size, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)
            # Upsample source predictions to the size reported by the loader.
            interp = nn.Upsample(size=(size[1], size[0]),
                                 mode='bilinear',
                                 align_corners=True)

            pred1 = model(images)
            pred1 = interp(pred1)

            loss_seg1 = seg_loss(pred1, labels)

            loss = loss_seg1

            # proper normalization
            loss = loss / args.iter_size / world_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size

            _, batch = next(targetloader_iter)
            # train with target
            images, _, _ = batch
            images = images.to(device)

            pred_target1 = model(images)
            pred_target1 = interp_target(pred_target1)

            # dim=1 is what the implicit-dim softmax selected for 4-D input;
            # stating it explicitly avoids the deprecation warning.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            # The generator is rewarded when D labels target output as source.
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))

            loss = args.lambda_adv_target1 * loss_adv_target1
            loss = loss / args.iter_size / world_size

            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            D_out1 = model_D1(F.softmax(pred1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(source_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with target
            pred_target1 = pred_target1.detach()
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            loss_D1 = bce_loss(
                D_out1,
                torch.FloatTensor(
                    D_out1.data.size()).fill_(target_label).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2 / world_size
            loss_D1.backward()
            if args.dist:
                average_gradients(model)
                average_gradients(model_D1)
                average_gradients(model_D2)

            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()

        reached_stop = i_iter >= args.num_steps_stop - 1

        if rank == 0:
            if args.tensorboard:
                scalar_info = {
                    'loss_seg1': loss_seg_value1,
                    'loss_seg2': loss_seg_value2,
                    'loss_adv_target1': loss_adv_target_value1,
                    'loss_adv_target2': loss_adv_target_value2,
                    'loss_D1': loss_D_value1 * world_size,
                    'loss_D2': loss_D_value2 * world_size,
                }

                if i_iter % 10 == 0:
                    for key, val in scalar_info.items():
                        writer.add_scalar(key, val, i_iter)

            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))

            if reached_stop:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            elif i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D1.state_dict(),
                    osp.join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '_D1.pth'))

        # Every rank must leave the loop at the same iteration; if only rank 0
        # stopped, the other ranks would keep calling average_gradients with
        # no peer.
        if reached_stop:
            break

    print(args.snapshot_dir)
    if args.tensorboard and rank == 0:
        writer.close()
def main():
    """Create the model and start the evaluation process.

    Reads ``opts.yaml`` next to the checkpoint to pick the model type and
    normalization style, restores the weights, and evaluates the Cityscapes
    loader.  For the ``'DeepLab'`` model the class scores are accumulated over
    two input scales (1.0 and 1.25), each evaluated on the input and on its
    ``fliplr`` counterpart; a per-pixel heatmap is derived from the KL
    divergence between the two prediction heads at the base scale.  Label
    maps, heatmaps (when available) and score maps are written by the
    ``save`` / ``save_heatmap`` / ``save_scoremap`` helpers through a process
    pool.

    Returns:
        The output directory path (``args.save`` suffixed with the name of
        the checkpoint's parent directory).

    Raises:
        ValueError: if the configured model type is not supported.
    """
    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load: yaml.load without an explicit Loader is rejected by
        # current PyYAML releases and can construct arbitrary objects on
        # older ones.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except (RuntimeError, KeyError):
        # Key mismatch: retry with the model wrapped in DataParallel, which
        # changes the parameter names the state dict is matched against.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(512, 1024),
                                                   resize_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    scale = 1.25
    testloader2 = data.DataLoader(cityscapesDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(512 * scale), round(1024 * scale)),
        resize_size=(round(1024 * scale), round(512 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)
    # NOTE(review): the 0.9-scale loader is iterated in lock-step below but
    # its batches are not used by the model.
    scale = 0.9
    testloader3 = data.DataLoader(cityscapesDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(512 * scale), round(1024 * scale)),
        resize_size=(round(1024 * scale), round(512 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    log_sm = torch.nn.LogSoftmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')

    for index, img_data in enumerate(zip(testloader, testloader2,
                                         testloader3)):
        batch, batch2, _ = img_data
        image, _, _, name = batch
        image2, _, _, _ = batch2

        # Put the inputs on the same device the model was moved to.
        inputs = image.cuda(gpu0)
        inputs2 = image2.cuda(gpu0)
        print('\r>>>>Extracting feature...%03d/%03d' %
              (index * batchsize, NUM_STEPS),
              end='')

        # Only the two-head 'DeepLab' model yields a heatmap.
        heatmap_batch = None
        with torch.no_grad():
            if args.model == 'DeepLab':
                # Base scale, original orientation.
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                heatmap_output1, heatmap_output2 = output1, output2

                # Base scale, flipped input; flip predictions back.
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                heatmap_output1 = heatmap_output1 + output1
                heatmap_output2 = heatmap_output2 + output2
                del output1, output2, inputs

                # 1.25 scale, original orientation.
                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))

                # 1.25 scale, flipped input; flip predictions back.
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2

                output_batch = output_batch.cpu().data.numpy()
                # Per-pixel KL divergence between the two heads, summed over
                # the class dimension.
                heatmap_batch = torch.sum(kl_distance(log_sm(heatmap_output1),
                                                      sm(heatmap_output2)),
                                          dim=1)
                heatmap_batch = torch.log(
                    1 + 10 * heatmap_batch)  # for visualization
                heatmap_batch = heatmap_batch.cpu().data.numpy()
            else:
                # 'DeeplabVGG' or 'Oracle': single head, single pass.
                output_batch = model(inputs)
                output_batch = interp(output_batch).cpu().data.numpy()

        # (N, C, H, W) -> (N, H, W, C); best score and best class per pixel.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        scoremap_batch = np.asarray(np.max(output_batch, axis=3))
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)
        output_iterator = []
        heatmap_iterator = []
        scoremap_iterator = []
        # Collected separately so the loader's `name` container (which may
        # be immutable) is not modified.
        save_names = []

        for i in range(output_batch.shape[0]):
            output_iterator.append(output_batch[i, :, :])
            if heatmap_batch is not None:
                heatmap_iterator.append(heatmap_batch[i, :, :] /
                                        np.max(heatmap_batch[i, :, :]))
            scoremap_iterator.append(1 - scoremap_batch[i, :, :] /
                                     np.max(scoremap_batch[i, :, :]))
            name_tmp = name[i].split('/')[-1]
            save_names.append('%s/%s' % (args.save, name_tmp))
        with Pool(4) as p:
            p.map(save, zip(output_iterator, save_names))
            if heatmap_iterator:
                p.map(save_heatmap, zip(heatmap_iterator, save_names))
            p.map(save_scoremap, zip(scoremap_iterator, save_names))

        del output_batch

    return args.save
Example #20
0
def main():
    """Create the model and start the evaluation process.

    Restores the network selected by ``args.model`` from a local checkpoint,
    runs it over the Cityscapes loader one image at a time and writes a
    label-index image and a colorized PNG for each image under
    ``args.save/0``.  Sub-directories ``0`` .. ``4`` are created up front.

    Raises:
        ValueError: if ``args.model`` is not one of the supported model names.
    """

    args = get_arguments()

    # NOTE(review): the GPU index is hard-coded to 2 here rather than taken
    # from the arguments — confirm this is intended.
    gpu0 = 2
    torch.cuda.manual_seed(1337)
    torch.cuda.set_device(2)

    os.makedirs(args.save, exist_ok=True)
    for i in range(5):
        os.makedirs(args.save + '/' + str(i), exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError('Unsupported model type: %s' % args.model)
    print("begin")

    if args.restore_from[:4] == 'http':
        # NOTE(review): no weights are loaded for URL checkpoints, so the
        # model keeps its constructor initialization — confirm intent.
        print("1112222")
    else:
        print("2222222", gpu0)
        print(args.restore_from)
        model.load_state_dict(torch.load(args.restore_from))
    model.cuda(gpu0)
    # NOTE(review): model.eval() is not called, so the network is evaluated
    # in whatever mode its constructor leaves it in — confirm intent.

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = Upsample_function
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, _, name = batch
            image = Variable(image).cuda(gpu0)
            final = []
            if args.model == 'DeeplabMulti':
                output1, output2 = model(image)
                output1 = F.softmax(output1, 1)
                output2 = F.softmax(output2, 1)
                # Candidate blend weights for output1 (in tenths).  Only the
                # first one (0, i.e. output2 alone) is evaluated because the
                # loop stops after its first pass.
                for i in [0, 3, 7, 10]:
                    final_output = i / 10.0 * output1 + (10.0 -
                                                         i) / 10.0 * output2
                    output = interp(final_output).cpu().data[0].numpy()
                    final.append(output)
                    break
            else:
                # 'DeeplabVGG': single head.  The image is already on the GPU
                # and gradients are disabled by the enclosing no_grad block,
                # so no extra Variable wrapper is needed.
                output = model(image)
                output = interp(output).cpu().data[0].numpy()
                # Queue the prediction so the saving loop below finds it
                # (previously `final` stayed empty and indexing it failed).
                final.append(output)

            name = name[0].split('/')[-1]
            # Only the first prediction is written (into sub-directory "0").
            for i, output in enumerate(final[:1]):
                # (C, H, W) -> (H, W, C), then pick the best class per pixel.
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                output.save('%s/%s/%s' % (args.save, str(i), name))
                output_col.save('%s/%s/%s_color.png' %
                                (args.save, str(i), name.split('.')[0]))
Example #21
0
def main():
    """Create the model and start the adversarial training.

    All configuration is read from the module-level ``args`` namespace.
    Every iteration performs, in this order:

    1. a supervised segmentation step on one batch from the source loader;
    2. an adversarial step on one batch from the target loader, in which the
       (frozen) discriminator output for the target features is pushed
       toward ``source_label``;
    3. a discriminator step on the detached source and target features:
       a domain BCE loss on both, plus a segmentation loss on the
       discriminator's first output for the source batch, weighted by
       ``args.lambda_s``.

    Both optimizers step once per iteration.  Checkpoints are written to
    ``args.snapshot_dir`` every ``args.save_pred_every`` iterations, and once
    more when iteration ``args.num_steps_stop - 1`` is reached, which also
    ends training.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """
    if args.model not in ('ResNet', 'VGG'):
        # Fail fast with a clear message instead of hitting an unbound
        # `model` further down.
        raise ValueError(
            "unsupported args.model %r (expected 'ResNet' or 'VGG')" %
            (args.model, ))

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Sizes are given as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from.startswith('http'):
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Checkpoint keys look like "Scale.layer5.conv2d_list.3.weight":
        # the leading component is dropped to match this model's keys.
        # With 19 classes the checkpoint's `layer5` entries are skipped, so
        # the model keeps its own freshly initialised values for them.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if args.num_classes != 19 or key_parts[1] != 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        model = DeeplabVGG(num_classes=args.num_classes,
                           vgg16_caffe_path='./model/vgg16_init.pth',
                           pretrained=True)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    if args.model == 'ResNet':
        model_D = FCDiscriminator(num_classes=2048).to(device)
    else:
        model_D = FCDiscriminator(num_classes=1024).to(device)

    model_D.train()

    os.makedirs(args.snapshot_dir, exist_ok=True)

    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=max_iters,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    source_iter = iter(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    target_iter = iter(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample takes (H, W) while input_size is stored as (W, H).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # set up tensor board
    writer = None
    if args.tensorboard:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    def save_checkpoint(tag):
        # Writes GTA5_<tag>.pth (segmentation net) and GTA5_<tag>_D.pth
        # (discriminator) into the snapshot directory.
        torch.save(model.state_dict(),
                   osp.join(args.snapshot_dir, 'GTA5_' + str(tag) + '.pth'))
        torch.save(
            model_D.state_dict(),
            osp.join(args.snapshot_dir, 'GTA5_' + str(tag) + '_D.pth'))

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # ---- train G ----

        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        images, labels, _, _ = next(source_iter)
        images = images.to(device)
        labels = labels.long().to(device)

        feature, prediction = model(images)
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        loss.backward()
        loss_seg = loss.item()

        # train with target
        images, _, _ = next(target_iter)
        images = images.to(device)

        feature_target, _ = model(images)
        _, D_out = model_D(feature_target)
        # Adversarial objective: target features are scored against the
        # *source* label.
        loss_adv_target = bce_loss(D_out,
                                   torch.full_like(D_out, source_label))
        loss = args.lambda_adv_target * loss_adv_target
        loss.backward()
        loss_adv_target_value = loss_adv_target.item()

        # ---- train D ----

        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source; features are detached so that no gradient
        # reaches the segmentation network from the discriminator losses.
        cla, D_out = model_D(feature.detach())
        cla = interp(cla)
        loss_cla = seg_loss(cla, labels)

        loss_D = bce_loss(D_out, torch.full_like(D_out, source_label)) / 2
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()

        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        _, D_out = model_D(feature_target.detach())
        loss_D = bce_loss(D_out, torch.full_like(D_out, target_label)) / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        if writer is not None and i_iter % 10 == 0:
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_D': loss_D_value,
            }
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_checkpoint(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_checkpoint(i_iter)

    if writer is not None:
        writer.close()
Example #22
0
def main():
    """Create the model and start the training with periodic validation.

    All configuration is read from the module-level ``args`` namespace.  The
    segmentation network is restored from ``args.restore_from``; for the
    ``'ResNet'`` model the discriminator is additionally restored from
    ``./snapshots/BestGTA5_D.pth``.

    Every iteration performs, in this order:

    1. a supervised segmentation step on one source batch;
    2. an adversarial step on one target batch (frozen discriminator,
       target features scored against ``source_label``);
    3. a segmentation step on the same target batch against the labels
       delivered by the target loader (``labels_pseudo``);
    4. after iteration 10000, a class-center alignment step built from
       ``class_center_cal`` / ``class_center_update``: the squared
       difference between the running source and target centers is
       minimised, with the source center detached;
    5. a discriminator step on the detached source and target features.

    Every ``args.save_pred_every`` iterations the network is evaluated on
    the ``'val'`` split, predictions are written to ``args.save``, mIoU is
    computed with ``compute_mIoU``, and checkpoints are saved to
    ``args.snapshot_dir`` (plus ``BestGTA5*.pth`` when mIoU improves).
    Training ends at iteration ``args.num_steps_stop - 1``.

    Raises:
        ValueError: if ``args.model`` is neither ``'ResNet'`` nor ``'VGG'``.
    """
    if args.model not in ('ResNet', 'VGG'):
        # Fail fast with a clear message instead of hitting an unbound
        # `model` further down.
        raise ValueError(
            "unsupported args.model %r (expected 'ResNet' or 'VGG')" %
            (args.model, ))

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Sizes are given as "W,H" strings.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    bestIoU = 0
    bestIter = 0

    # Class-center alignment only kicks in after this many iterations.
    center_align_start = 10000

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
    else:
        model = DeeplabVGG(num_classes=args.num_classes)
    model.load_state_dict(torch.load(args.restore_from))

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=256).to(device)
    if args.model == 'ResNet':
        # NOTE(review): hard-coded checkpoint path; only the ResNet variant
        # restores discriminator weights.
        model_D.load_state_dict(torch.load('./snapshots/BestGTA5_D.pth'))

    model_D.train()

    os.makedirs(args.snapshot_dir, exist_ok=True)

    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=max_iters,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    source_iter = iter(trainloader)

    targetloader = data.DataLoader(cityscapesDataSetLabel(
        args.data_dir_target,
        args.data_list_target,
        max_iters=max_iters,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    target_iter = iter(targetloader)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # nn.Upsample takes (H, W) while the sizes are stored as (W, H).
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    test_interp = nn.Upsample(size=(1024, 2048),
                              mode='bilinear',
                              align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # load precomputed class centers for initialization
    class_center_source_ori = torch.from_numpy(np.load('./source_center.npy'))
    class_center_target_ori = torch.from_numpy(np.load('./target_center.npy'))

    # set up tensor board
    writer = None
    if args.tensorboard:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)

    def save_checkpoint(tag):
        # Writes GTA5_<tag>.pth (segmentation net) and GTA5_<tag>_D.pth
        # (discriminator) into the snapshot directory.
        torch.save(model.state_dict(),
                   osp.join(args.snapshot_dir, 'GTA5_' + str(tag) + '.pth'))
        torch.save(
            model_D.state_dict(),
            osp.join(args.snapshot_dir, 'GTA5_' + str(tag) + '_D.pth'))

    for i_iter in range(args.num_steps):

        # Stays 0 until the center-alignment phase starts.
        loss_square_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # ---- train G ----

        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        # `labels_s` keeps the labels exactly as delivered by the loader
        # (before the cast/device move) for the center computation below.
        images, labels_s, _, _ = next(source_iter)
        images = images.to(device)
        labels = labels_s.long().to(device)

        feature, prediction = model(images)
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        # The graph is retained because `feature` is reused (undetached)
        # for the class-center loss later in this iteration.
        loss.backward(retain_graph=True)
        loss_seg = loss.item()

        # train with target
        images, labels_t, _, _ = next(target_iter)
        images = images.to(device)
        labels_pseudo = labels_t.long().to(device)

        feature_target, pred_target = model(images)
        _, D_out = model_D(feature_target)
        # Adversarial objective: target features are scored against the
        # *source* label.
        loss_adv_target = bce_loss(D_out,
                                   torch.full_like(D_out, source_label))
        loss = args.lambda_adv_target * loss_adv_target
        loss.backward(retain_graph=True)
        loss_adv_target_value = loss_adv_target.item()

        pred_target = interp_target(pred_target)
        loss_st = seg_loss(pred_target, labels_pseudo)
        loss_st.backward(retain_graph=True)
        loss_st_value = loss_st.item()

        # class center alignment begin
        if i_iter > center_align_start:
            class_center_source = class_center_cal(feature, labels_s)
            class_center_target = class_center_cal(feature_target, labels_t)
            class_center_source_ori = class_center_update(
                class_center_source, class_center_source_ori,
                args.lambda_center_update)
            class_center_target_ori = class_center_update(
                class_center_target, class_center_target_ori,
                args.lambda_center_update)

            # align target center to source: no gradient through the
            # source center
            class_center_source_ori = class_center_source_ori.detach()

            center_diff = class_center_source_ori - class_center_target_ori
            loss_square = torch.pow(center_diff, 2).sum()

            loss = args.lambda_center * loss_square
            loss.backward()
            loss_square_value = loss_square.item()
        # class center alignment end

        # ---- train D ----

        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source; features are detached so that no gradient
        # reaches the segmentation network from the discriminator losses.
        cla, D_out = model_D(feature.detach())
        cla = interp(cla)
        loss_cla = seg_loss(cla, labels)

        loss_D = bce_loss(D_out, torch.full_like(D_out, source_label)) / 2
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()

        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        _, D_out = model_D(feature_target.detach())
        loss_D = bce_loss(D_out, torch.full_like(D_out, target_label)) / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        # Cut the running target center loose from this iteration's graph
        # before it is carried into the next one.
        class_center_target_ori = class_center_target_ori.detach()

        if writer is not None and i_iter % 10 == 0:
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_st_value': loss_st_value,
                'loss_D': loss_D_value,
                # Logged so TensorBoard matches the console output below.
                'loss_square': loss_square_value,
            }
            for key, val in scalar_info.items():
                writer.add_scalar(key, val, i_iter)

        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f} loss_st = {6:.5f} loss_square = {7:.5f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value, loss_st_value,
                    loss_square_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_checkpoint(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            os.makedirs(args.save, exist_ok=True)
            testloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target,
                args.data_list_target_test,
                crop_size=(1024, 512),
                mean=IMG_MEAN,
                scale=False,
                mirror=False,
                set='val'),
                                         batch_size=1,
                                         shuffle=False,
                                         pin_memory=True)
            model.eval()
            with torch.no_grad():
                for index, batch in enumerate(testloader):
                    if index % 100 == 0:
                        print('%d processd' % index)
                    image, _, name = batch
                    # Only the second model output is used for prediction.
                    _, output2 = model(image.to(device))
                    output = test_interp(output2)[0].cpu().numpy()
                    # (C, H, W) -> (H, W, C), then per-pixel argmax over C.
                    output = output.transpose(1, 2, 0)
                    output = np.asarray(np.argmax(output, axis=2),
                                        dtype=np.uint8)
                    output = Image.fromarray(output)
                    name = name[0].split('/')[-1]
                    output.save('%s/%s' % (args.save, name))
            mIoUs = compute_mIoU(osp.join(args.data_dir_target, 'gtFine/val'),
                                 args.save, 'dataset/cityscapes_list')
            mIoU = round(np.nanmean(mIoUs) * 100, 2)

            print('===>  current   mIoU: ' + str(mIoU))
            print('===> last best  mIoU: ' + str(bestIoU))
            print('===> last best  iter: ' + str(bestIter))

            if mIoU > bestIoU:
                bestIoU = mIoU
                bestIter = i_iter
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5.pth'))
                torch.save(model_D.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5_D.pth'))
            save_checkpoint(i_iter)
            # Evaluation switched the network to eval mode; switch back.
            model.train()

    if writer is not None:
        writer.close()