Esempio n. 1
0
def test():
    """Visually inspect augmented samples from the labelled Corrosion train split.

    Builds the training dataset with sigmoid-style label transforms
    (``ToTensorLabel(tensor_type=torch.FloatTensor)``) and 512x512
    ``RandomSizedCrop3`` co-transforms. For every batch it shows the image,
    the mask and the fourth item yielded by the dataset side by side.

    Relies on module-level ``torch``, ``home_dir``, ``args``, ``Corrosion``
    and ``DataLoader`` defined elsewhere in the file.
    """
    from utils.transforms import IgnoreLabelClass, ToTensorLabel, NormalizeOwn, RandomSizedCrop3
    from torchvision.transforms import ToTensor, Compose
    import matplotlib.pyplot as plt

    imgtr = [ToTensor(), NormalizeOwn()]
    # sigmoid
    labtr = [IgnoreLabelClass(), ToTensorLabel(tensor_type=torch.FloatTensor)]
    cotr = [RandomSizedCrop3((512, 512))]

    dataset_dir = '/media/data/seg_dataset'
    trainset = Corrosion(home_dir,
                         dataset_dir,
                         img_transform=Compose(imgtr),
                         label_transform=Compose(labtr),
                         co_transform=Compose(cotr),
                         split=args.split,
                         labeled=True)
    # The loader must wrap the dataset built just above; the previous
    # reference to ``trainset_l`` was an undefined name.
    trainloader = DataLoader(trainset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    for img, mask, _, emask in trainloader:
        img, mask, emask = img.numpy()[0], mask.numpy()[0], emask.numpy()[0]
        # ToTensor yields channel-first (C, H, W); matplotlib wants (H, W[, C]).
        img = img.transpose(1, 2, 0)
        if img.shape[2] == 1:
            img = img[..., 0]
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        ax1.imshow(img)
        ax2.imshow(mask)
        ax3.imshow(emask)
        plt.show()
Esempio n. 2
0
def main():
    """Train an attention U-Net on the Corrosion dataset.

    ``args.mode`` selects the regime:

    * ``'base'`` - supervised generator training only;
    * ``'adv'``  - adversarial training with a 2-channel discriminator;
    * ``'semi'`` - adversarial training that also consumes unlabelled images.

    Any other mode still builds the models, then only prints an error message.
    """
    args = parse_args()

    # Fixed seeds for reproducible runs.
    random.seed(0)
    torch.manual_seed(0)
    if not args.nogpu:
        torch.cuda.manual_seed_all(0)

    # ---- training transforms -------------------------------------------------
    train_img_ops = [ToTensor()] if args.no_norm else [ToTensor(), NormalizeOwn()]
    # softmax: integer label maps
    train_lab_ops = [IgnoreLabelClass(), ToTensorLabel()]
    train_co_ops = [RandomSizedCrop3((320, 320))]

    print("dataset_dir: ", args.dataset_dir)

    def build_train_loader(labeled):
        """Return a shuffled, drop-last DataLoader over the (un)labelled split."""
        subset = Corrosion(home_dir, args.dataset_dir,
                           img_transform=Compose(train_img_ops),
                           label_transform=Compose(train_lab_ops),
                           co_transform=Compose(train_co_ops),
                           split=args.split, labeled=labeled)
        return DataLoader(subset, batch_size=args.batch_size, shuffle=True,
                          num_workers=2, drop_last=True)

    trainloader_l = build_train_loader(labeled=True)
    if args.mode == 'semi':
        trainloader_u = build_train_loader(labeled=False)

    # ---- validation transforms -----------------------------------------------
    if args.val_orig:
        # original resolution: zero-pad only, no cropping
        val_img_ops = [ZeroPadding(), ToTensor()]
        val_co_ops = []
    else:
        val_img_ops = [ToTensor()]
        val_co_ops = [RandomSizedCrop3((320, 320))]
    if not args.no_norm:
        val_img_ops.append(NormalizeOwn())
    val_lab_ops = [IgnoreLabelClass(), ToTensorLabel()]

    valset = Corrosion(home_dir, args.dataset_dir,
                       img_transform=Compose(val_img_ops),
                       label_transform=Compose(val_lab_ops),
                       co_transform=Compose(val_co_ops),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    # ---- generator -----------------------------------------------------------
    # softmax generator: in_chs=3, out_chs=2
    generator = unet.AttU_Net()
    init_weights(generator, args.init_net)

    trainable_g = filter(lambda p: p.requires_grad, generator.parameters())
    if args.init_net == 'unet':
        optimG = optim.Adam(trainable_g, args.g_lr, [0.5, 0.999])
    else:
        optimG = optim.SGD(trainable_g, lr=args.g_lr, momentum=0.9,
                           weight_decay=0.0001, nesterov=True)

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    # ---- discriminator -------------------------------------------------------
    if args.mode != "base":
        # consumes the generator's 2-channel output
        discriminator = Dis(in_channels=2)
        trainable_d = filter(lambda p: p.requires_grad, discriminator.parameters())
        if args.d_optim == 'adam':
            optimD = optim.Adam(trainable_d, lr=args.d_lr, weight_decay=0.0001)
        else:
            optimD = optim.SGD(trainable_d, lr=args.d_lr, weight_decay=0.0001,
                               momentum=0.5, nesterov=True)
        if not args.nogpu:
            discriminator = nn.DataParallel(discriminator).cuda()

    # ---- dispatch ------------------------------------------------------------
    if args.mode == 'base':
        train_base(generator, optimG, trainloader_l, valoader, args)
    elif args.mode == 'adv':
        train_adv(generator, discriminator, optimG, optimD,
                  trainloader_l, valoader, args)
    elif args.mode == 'semi':
        train_semi(generator, discriminator, optimG, optimD,
                   trainloader_l, trainloader_u, valoader, args)
    else:
        print("training mode incorrect")
Esempio n. 3
0
def main():
    """Evaluate a saved DeepLab-v2 generator snapshot on the Corrosion test set.

    Command-line arguments:
        dataset_dir: directory containing the image and ground-truth folders.
        snapshot: checkpoint whose ``'state_dict'`` entry holds the generator
            weights (keys may carry the ``module.`` prefix added by
            ``nn.DataParallel``).
        --val_orig: evaluate at the original image size (zero-padded input,
            padding cropped from the prediction); otherwise inputs are
            cropped to 321x321 as in training.
        --norm: apply ``NormalizeOwn(dataset='corrosion')`` to the inputs.

    Prints the score returned by ``scores`` over all test images.

    Raises:
        FileNotFoundError: if the snapshot path does not exist.
    """
    home_dir = os.path.dirname(os.path.realpath(__file__))

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_dir",
                        help="A directory containing img (Images) \
                        and cls (GT Segmentation) folder")
    parser.add_argument("snapshot", help="Snapshot with the saved model")
    parser.add_argument("--val_orig",
                        help="Do Inference on original size image.\
                        Otherwise, crop to 321x321 like in training ",
                        action='store_true')
    parser.add_argument("--norm",help="Normalize the test images",\
                        action='store_true')
    args = parser.parse_args()

    # The branches follow the --val_orig help text (and the training scripts'
    # convention): original size => zero-pad and no crop; otherwise crop like
    # in training. They were previously swapped.
    if args.val_orig:
        img_ops = [ZeroPadding(), ToTensor()]
        co_transform = None
    else:
        img_ops = [ToTensor()]
        co_transform = transforms.Compose([RandomSizedCrop((321, 321))])
    if args.norm:
        img_ops.append(NormalizeOwn(dataset='corrosion'))
    img_transform = transforms.Compose(img_ops)
    label_transform = transforms.Compose(
        [IgnoreLabelClass(), ToTensorLabel()])

    dataset_kwargs = {
        'img_transform': img_transform,
        'label_transform': label_transform,
        'train_phase': False,
    }
    if co_transform is not None:
        dataset_kwargs['co_transform'] = co_transform
    testset = Corrosion(home_dir, args.dataset_dir, **dataset_kwargs)
    testloader = DataLoader(testset, batch_size=1)

    generator = deeplabv2.ResDeeplab()
    if not os.path.isfile(args.snapshot):
        raise FileNotFoundError(args.snapshot)
    snapshot = torch.load(args.snapshot)

    # Strip only a leading DataParallel prefix; un-prefixed keys are kept as-is
    # (``k.partition('module.')[2]`` used to turn them into '').
    prefix = 'module.'
    saved_net = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in snapshot['state_dict'].items()
    }
    print('Snapshot Loaded')
    generator.load_state_dict(saved_net)
    generator.eval()
    generator = nn.DataParallel(generator).cuda()
    print('Generator Loaded')
    # NOTE(review): n_classes is 2 for Corrosion; confirm ResDeeplab() is
    # configured with a matching number of output channels.
    n_classes = 2

    gts, preds = [], []

    print('Prediction Goint to Start')

    for img_id, batch in enumerate(testloader):
        print("Generating Predictions for Image {}".format(img_id))
        # Only image and ground truth are needed; the dataset may yield more.
        img, gt_mask = batch[0], batch[1]
        gt_mask = gt_mask.numpy()[0]
        img = Variable(img.cuda())
        out_pred_map = generator(img)

        # Hard prediction, cropped back to the ground-truth size so any
        # zero-padding added by ZeroPadding is discarded.
        soft_pred = out_pred_map.data.cpu().numpy()[0]
        soft_pred = soft_pred[:, :gt_mask.shape[0], :gt_mask.shape[1]]
        hard_pred = np.argmax(soft_pred, axis=0).astype(np.uint8)

        for gt_, pred_ in zip(gt_mask, hard_pred):
            gts.append(gt_)
            preds.append(pred_)
    score, _ = scores(gts, preds, n_class=n_classes)

    print("Mean IoU: {}".format(score))
Esempio n. 4
0
def evaluate_discriminator():
    """Run a trained generator/discriminator pair on the Corrosion test set.

    For every test image this writes, under ``results/``:

    * ``<name>.png``     - colourised generator prediction,
    * ``<name>_vis.png`` - generator prediction overlaid on the input image,
    * ``<name>_dis.png`` - colourised discriminator prediction,

    and finally prints the generator score returned by ``scores``.

    Command-line arguments:
        dataset_dir: directory containing the image and ground-truth folders;
            input JPEGs are read from ``<dataset_dir>/corrosion/JPEGImages``.
        snapshot_g / snapshot_d: checkpoints whose ``'state_dict'`` entry holds
            the generator / discriminator weights.
        --val_orig: evaluate at the original image size (zero-padded input);
            otherwise inputs go through ``ResizedImage3((320, 320))``.
        --norm: apply ``NormalizeOwn(dataset='corrosion')`` to the inputs.

    Raises:
        FileNotFoundError: if a snapshot or an input JPEG cannot be found.
    """
    home_dir = os.path.dirname(os.path.realpath(__file__))

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_dir",
                        help="A directory containing img (Images) \
                        and cls (GT Segmentation) folder")
    parser.add_argument("snapshot_g",
                        help="Snapshot with the saved generator model")
    parser.add_argument("snapshot_d",
                        help="Snapshot with the saved discriminator model")
    parser.add_argument("--val_orig",
                        help="Do Inference on original size image.\
                        Otherwise, crop to 320x320 like in training ",
                        action='store_true')
    parser.add_argument("--norm",help="Normalize the test images",\
                        action='store_true')
    args = parser.parse_args()

    # Follow the --val_orig help text (and the training scripts' convention):
    # original size => zero-pad, no resize. These branches were swapped.
    if args.val_orig:
        img_ops = [ZeroPadding(), ToTensor()]
        co_transform = None
    else:
        img_ops = [ToTensor()]
        co_transform = transforms.Compose([ResizedImage3((320, 320))])
    if args.norm:
        img_ops.append(NormalizeOwn(dataset='corrosion'))
    img_transform = transforms.Compose(img_ops)
    label_transform = transforms.Compose(
        [IgnoreLabelClass(), ToTensorLabel()])

    dataset_kwargs = {
        'img_transform': img_transform,
        'label_transform': label_transform,
        'train_phase': False,
    }
    if co_transform is not None:
        dataset_kwargs['co_transform'] = co_transform
    testset = Corrosion(home_dir, args.dataset_dir, **dataset_kwargs)
    testloader = DataLoader(testset, batch_size=1)

    def load_weights(path):
        """Load ``path`` and return its state dict without a ``module.`` prefix."""
        print(path)
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        state = torch.load(path)['state_dict']
        prefix = 'module.'
        return {(k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state.items()}

    # generator = deeplabv2.ResDeeplab()
    # generatro = fcn.FCN8s_soft()
    generator = unet.AttU_Net()
    saved_net = load_weights(args.snapshot_g)

    discriminator = Dis(in_channels=2)
    saved_net_d = load_weights(args.snapshot_d)

    print('Generator Snapshot Loaded')
    generator.load_state_dict(saved_net)
    generator.eval()
    generator = nn.DataParallel(generator).cuda()
    print('Generator Loaded')

    print('Discriminator Snapshot Loaded')
    discriminator.load_state_dict(saved_net_d)
    discriminator.eval()
    discriminator = nn.DataParallel(discriminator).cuda()
    print('discriminator Loaded')
    n_classes = 2

    gts, preds = [], []
    print('Prediction Goint to Start')
    colorize = VOCColorize()
    palette = make_palette(2)
    IMG_DIR = osp.join(args.dataset_dir, 'corrosion/JPEGImages')
    results_dir = 'results'
    os.makedirs(results_dir, exist_ok=True)

    for img_id, (img, gt_mask, _, _, name) in enumerate(testloader):
        print("Generating Predictions for Image {}".format(img_id))
        gt_mask = gt_mask.numpy()[0]
        height, width = gt_mask.shape[0], gt_mask.shape[1]
        img = Variable(img.cuda())
        out_pred_map = generator(img)

        # Generator hard prediction, cropped to the ground-truth size so any
        # zero-padding is discarded.
        soft_pred = out_pred_map.data.cpu().numpy()[0]
        soft_pred = soft_pred[:, :height, :width]
        hard_pred = np.argmax(soft_pred, axis=0).astype(np.uint8)

        # NOTE(review): the discriminator is fed raw generator outputs here,
        # whereas train_semi in this file feeds it softmax probabilities;
        # confirm which input this discriminator was trained on.
        dis_conf = discriminator(out_pred_map)
        dis_soft_pred = nn.Softmax2d()(dis_conf).data.cpu().numpy()[0]
        # Crop like the generator prediction (no-op when sizes already match).
        dis_soft_pred = dis_soft_pred[:, :height, :width]
        dis_hard_pred = np.argmax(dis_soft_pred, axis=0).astype(np.uint8)

        # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is what it aliased.
        output = np.asarray(hard_pred, dtype=int)

        img_path = osp.join(IMG_DIR, name[0] + '.jpg')
        print(img_path)
        img_array = cv2.imread(img_path)
        if img_array is None:
            raise FileNotFoundError(img_path)
        # Resize the overlay image to the prediction's size (cv2 takes (w, h)),
        # so the visualisation works for both the padded and resized paths.
        img_array = cv2.resize(img_array, (output.shape[1], output.shape[0]),
                               interpolation=cv2.INTER_AREA)
        # cv2.imread returns BGR; PIL expects RGB.
        img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)

        filename = os.path.join(results_dir, '{}.png'.format(name[0]))
        color_file = Image.fromarray(
            colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        masked_im = Image.fromarray(vis_seg(img_array, output, palette))
        masked_im.save(os.path.join(results_dir, '{}_vis.png'.format(name[0])))

        # Discriminator output. ``name[0]`` has no extension, so it is used
        # whole (slicing off 4 characters truncated the real image name).
        dis_output = np.asarray(dis_hard_pred, dtype=int)
        dis_filename = os.path.join(results_dir, '{}_dis.png'.format(name[0]))
        dis_color_file = Image.fromarray(
            colorize(dis_output).transpose(1, 2, 0), 'RGB')
        dis_color_file.save(dis_filename)

        for gt_, pred_ in zip(gt_mask, hard_pred):
            gts.append(gt_)
            preds.append(pred_)
    score, _ = scores(gts, preds, n_class=n_classes)
    print("Mean IoU: {}".format(score))
def main():
    """Train DeepLab-v2 on Pascal VOC (21 classes) in base/adv/semi mode.

    ``args.mode`` selects the regime:

    * ``'base'`` - supervised generator training only;
    * ``'adv'``  - adversarial training with a 21-channel discriminator;
    * ``'semi'`` - adversarial training that also uses the unlabelled split.

    Any other mode prints an error message instead of silently falling back
    to semi-supervised training (matching the other training entry points).
    """
    args = parse_args()

    random.seed(0)
    torch.manual_seed(0)
    if not args.nogpu:
        torch.cuda.manual_seed_all(0)

    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop((321, 321))]

    trainset_l = PascalVOC(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                           label_transform=Compose(labtr), co_transform=Compose(cotr),
                           split=args.split, labeled=True)
    trainset_u = PascalVOC(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                           label_transform=Compose(labtr), co_transform=Compose(cotr),
                           split=args.split, labeled=False)

    trainloader_l = DataLoader(trainset_l, batch_size=args.batch_size, shuffle=True,
                               num_workers=2, drop_last=True)
    trainloader_u = DataLoader(trainset_u, batch_size=args.batch_size, shuffle=True,
                               num_workers=2, drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        # original resolution: zero-pad, no crop
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                       label_transform=Compose(labtr), co_transform=Compose(cotr),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = deeplabv2.ResDeeplab()
    init_weights(generator, args.init_net)

    optimG = optim.SGD(filter(lambda p: p.requires_grad, generator.parameters()),
                       lr=args.g_lr, momentum=0.9,
                       weight_decay=0.0001, nesterov=True)

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    if args.mode != "base":
        discriminator = Dis(in_channels=21)
        if args.d_optim == 'adam':
            optimD = optim.Adam(filter(lambda p: p.requires_grad,
                                       discriminator.parameters()),
                                lr=args.d_lr, weight_decay=0.0001)
        else:
            optimD = optim.SGD(filter(lambda p: p.requires_grad,
                                      discriminator.parameters()),
                               lr=args.d_lr, weight_decay=0.0001,
                               momentum=0.5, nesterov=True)

        if not args.nogpu:
            discriminator = nn.DataParallel(discriminator).cuda()

    if args.mode == 'base':
        train_base(generator, optimG, trainloader_l, valoader, args)
    elif args.mode == 'adv':
        train_adv(generator, discriminator, optimG, optimD, trainloader_l, valoader, args)
    elif args.mode == 'semi':
        print("Semi-Supervised training")
        train_semi(generator, discriminator, optimG, optimD,
                   trainloader_l, trainloader_u, valoader, args)
    else:
        print("training mode incorrect")
Esempio n. 6
0
def main():
    """Train an attention U-Net on Corrosion, optionally with a sigmoid discriminator.

    DenseCRF settings are read from ``utils/config_crf.yaml`` (relative to
    the current working directory) and the resulting post-processor is passed
    to adversarial training. ``args.mode`` selects the regime:

    * ``'base'`` - supervised generator training only;
    * ``'adv'``  - adversarial training (receives the CRF post-processor);
    * ``'semi'`` - semi-supervised adversarial training; a split ratio of 0.8
      is passed to the datasets (presumably the labelled fraction - confirm
      against ``Corrosion``).

    Any other mode prints an error message.
    """
    args = parse_args()

    CUR_DIR = os.getcwd()
    with open(osp.join(CUR_DIR, "utils/config_crf.yaml")) as f:
        CRF_CONFIG = Dict(yaml.safe_load(f))

    random.seed(0)
    torch.manual_seed(0)
    if not args.nogpu:
        torch.cuda.manual_seed_all(0)

    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    # softmax
    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop3((320, 320))]

    print("dataset_dir: ", args.dataset_dir)
    split_ratio = 0.8 if args.mode == 'semi' else 1.0
    trainset_l = Corrosion(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                           label_transform=Compose(labtr), co_transform=Compose(cotr),
                           split=split_ratio, labeled=True)
    trainloader_l = DataLoader(trainset_l, batch_size=args.batch_size, shuffle=True,
                               num_workers=2, drop_last=True)

    if args.mode == 'semi':
        trainset_u = Corrosion(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                               label_transform=Compose(labtr), co_transform=Compose(cotr),
                               split=split_ratio, labeled=False)
        trainloader_u = DataLoader(trainset_u, batch_size=args.batch_size, shuffle=True,
                                   num_workers=2, drop_last=True)

    postprocessor = DenseCRF(
        iter_max=CRF_CONFIG.CRF.ITER_MAX,
        pos_xy_std=CRF_CONFIG.CRF.POS_XY_STD,
        pos_w=CRF_CONFIG.CRF.POS_W,
        bi_xy_std=CRF_CONFIG.CRF.BI_XY_STD,
        bi_rgb_std=CRF_CONFIG.CRF.BI_RGB_STD,
        bi_w=CRF_CONFIG.CRF.BI_W,
    )

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        # original resolution: zero-pad, no crop
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop3((320, 320))]

    valset = Corrosion(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                       label_transform=Compose(labtr), co_transform=Compose(cotr),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    # softmax generator: in_chs=3, out_chs=2
    generator = unet.AttU_Net()
    init_weights(generator, args.init_net)

    if args.init_net != 'unet':
        optimG = optim.SGD(filter(lambda p: p.requires_grad, generator.parameters()),
                           lr=args.g_lr, momentum=0.9,
                           weight_decay=0.0001, nesterov=True)
    else:
        optimG = optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()),
                            args.g_lr, [0.9, 0.999])

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    if args.mode != "base":
        # Takes the generator's 2-channel softmax output.
        discriminator = DisSigmoid(in_channels=2)
        # This section previously called ``init_weights(generator, ...)`` a
        # second time, re-initialising the (already DataParallel-wrapped)
        # generator rather than the discriminator. The generator is
        # initialised once above; the discriminator keeps the module's
        # default initialisation.
        if args.d_optim == 'adam':
            optimD = optim.Adam(filter(lambda p: p.requires_grad,
                                       discriminator.parameters()),
                                args.d_lr, [0.9, 0.999])
        else:
            optimD = optim.SGD(filter(lambda p: p.requires_grad,
                                      discriminator.parameters()),
                               lr=args.d_lr, weight_decay=0.0001,
                               momentum=0.9, nesterov=True)

        if not args.nogpu:
            discriminator = nn.DataParallel(discriminator).cuda()

    if args.mode == 'base':
        train_base(generator, optimG, trainloader_l, valoader, args)
    elif args.mode == 'adv':
        train_adv(generator, discriminator, optimG, optimD, trainloader_l,
                  valoader, postprocessor, args)
    elif args.mode == 'semi':
        train_semi(generator, discriminator, optimG, optimD, trainloader_l,
                   trainloader_u, valoader, args)
    else:
        print("training mode incorrect")
Esempio n. 7
0
def main():
    """Train an attention U-Net (7 output channels) on the BoxSet dataset.

    The generator is warm-started from ``args.snapshot`` when that file
    exists. Only checkpoint entries whose names match the generator are used;
    all other weights keep their current values. Without a snapshot the
    weights come from ``init_weights``. ``args.mode`` selects plain
    supervised training (``'base'``) or box-cluster label correction
    (``'label_correction'``, which receives the integer LR steps parsed from
    the comma-separated ``args.lr_step``). Any other mode prints an error
    message.
    """
    args = parse_args()

    random.seed(0)
    torch.manual_seed(0)
    if not args.nogpu:
        torch.cuda.manual_seed_all(0)

    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    # Always define ``steps`` so 'label_correction' no longer hits a NameError
    # when --lr_step is empty; blank pieces (e.g. a trailing comma) are ignored.
    steps = []
    if args.lr_step:
        steps = [int(s) for s in args.lr_step.split(',') if s.strip()]

    # softmax
    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop4((512, 512))]

    print("dataset_dir: ", args.dataset_dir)

    trainset_l = BoxSet(home_dir,
                        args.dataset_dir,
                        img_transform=Compose(imgtr),
                        label_transform=Compose(labtr),
                        co_transform=Compose(cotr),
                        split=args.split,
                        labeled=True,
                        label_correction=True)
    trainloader_l = DataLoader(trainset_l,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=2,
                               drop_last=True)
    if args.split != 1:
        trainset_u = BoxSet(home_dir,
                            args.dataset_dir,
                            img_transform=Compose(imgtr),
                            label_transform=Compose(labtr),
                            co_transform=Compose(cotr),
                            split=args.split,
                            labeled=False,
                            label_correction=True)
        # Must wrap the unlabelled set (it previously wrapped trainset_l).
        trainloader_u = DataLoader(trainset_u,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=2,
                                   drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        # softmax
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        # softmax
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [ResizedImage4((512, 512))]

    valset = BoxSet(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                    label_transform=Compose(labtr), co_transform=Compose(cotr),
                    train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = unet.AttU_Net(output_ch=7, Centroids=False)

    if args.snapshot and osp.isfile(args.snapshot):
        print("load checkpoint => ", args.snapshot)
        checkpoint = torch.load(args.snapshot)
        generator_dict = generator.state_dict()
        prefix = 'module.'
        saved_net = {}
        for k, v in checkpoint['state_dict'].items():
            # Strip only a leading DataParallel prefix; keep other keys intact.
            key = k[len(prefix):] if k.startswith(prefix) else k
            if key in generator_dict:
                saved_net[key] = v
        generator_dict.update(saved_net)
        # Load the merged dict: ``saved_net`` alone may lack keys (it is
        # filtered against the model), which a strict load rejects.
        generator.load_state_dict(generator_dict)
    else:
        init_weights(generator, args.init_net)

    # Both init_net variants used identical Adam settings.
    optimG = optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()),
                        args.g_lr, [0.5, 0.999])

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    if args.mode == 'base':
        train_base(generator, optimG, trainloader_l, valoader, args)
    elif args.mode == 'label_correction':
        train_box_cluster(generator, steps, optimG, trainloader_l, valoader,
                          args)
    else:
        print("training mode incorrect")
def train_semi(args):
    """Semi-supervised adversarial training of DeepLab-v2 on Pascal VOC.

    Each loaded batch (``2 * args.batch_size`` images) is split in half. One
    half is randomly treated as labelled and the other as unlabelled. Each
    iteration then:

    1. trains the discriminator to map ``ohmask`` targets to "real" (1) and
       the generator's softmax output to "fake" (0);
    2. trains the generator with cross-entropy on the labelled half plus
       ``args.lam_adv`` times the adversarial loss;
    3. after ``args.wait_semi`` epochs, self-trains the generator on
       unlabelled pixels whose discriminator "real" confidence exceeds
       ``args.t_semi``, using its own hard predictions as targets and
       weighting the loss by ``args.lam_semi``.

    A snapshot/validation pass runs after every epoch.

    Side effect: doubles ``args.batch_size`` in place.
    """
    # TODO: Make it more generic to include for other splits
    args.batch_size = args.batch_size * 2

    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop((321, 321))]

    trainset = PascalVOC(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                         label_transform=Compose(labtr), co_transform=Compose(cotr))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir, args.dataset_dir, img_transform=Compose(imgtr),
                       label_transform=Compose(labtr), co_transform=Compose(cotr),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = deeplabv2.ResDeeplab()
    optimG = optim.SGD(filter(lambda p: p.requires_grad, generator.parameters()),
                       lr=args.g_lr, momentum=0.9,
                       weight_decay=0.0001, nesterov=True)

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    discriminator = Dis(in_channels=21)
    if args.d_optim == 'adam':
        optimD = optim.Adam(filter(lambda p: p.requires_grad,
                                   discriminator.parameters()), lr=args.d_lr)
    else:
        optimD = optim.SGD(filter(lambda p: p.requires_grad,
                                  discriminator.parameters()),
                           lr=args.d_lr, weight_decay=0.0001,
                           momentum=0.5, nesterov=True)

    if not args.nogpu:
        discriminator = nn.DataParallel(discriminator).cuda()

    ############
    # TRAINING #
    ############
    # NOTE(review): best_miou is never reassigned in this loop; if snapshot()
    # returns the updated best score, assign its result here.
    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        for batch_id, (img, mask, ohmask) in enumerate(trainloader):
            if args.nogpu:
                img, mask, ohmask = Variable(img), Variable(mask), Variable(ohmask)
            else:
                img, mask, ohmask = (Variable(img.cuda()), Variable(mask.cuda()),
                                     Variable(ohmask.cuda()))
            itr = len(trainloader) * (epoch - 1) + batch_id
            ## TODO: Extend random interleaving for split of any size
            mid = args.batch_size // 2
            img1, mask1, ohmask1 = img[0:mid, ...], mask[0:mid, ...], ohmask[0:mid, ...]
            img2, mask2, ohmask2 = img[mid:, ...], mask[mid:, ...], ohmask[mid:, ...]

            # Random Interleaving
            if random.random() < 0.5:
                imgl, maskl, ohmaskl = img1, mask1, ohmask1
                imgu, masku, ohmasku = img2, mask2, ohmask2
            else:
                imgu, masku, ohmasku = img1, mask1, ohmask1
                imgl, maskl, ohmaskl = img2, mask2, ohmask2

            #############################################
            #  Labelled data for Discriminator Training #
            #############################################
            cpmap = generator(Variable(imgl.data, volatile=True))
            cpmap = nn.Softmax2d()(cpmap)

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]

            # Per-pixel targets: class 0 = fake, class 1 = real.
            targetf = Variable(torch.zeros((N, H, W)).long())
            targetr = Variable(torch.ones((N, H, W)).long())
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            # Train on Real
            confr = nn.LogSoftmax()(discriminator(ohmaskl.float()))
            optimD.zero_grad()
            if args.d_label_smooth != 0:
                LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(confr, targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr, targetf)
            else:
                LDr = nn.NLLLoss2d()(confr, targetr)
            LDr.backward()

            # Train on Fake
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff, targetf)
            LDf.backward()

            poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            ###########################################
            #  labelled data Generator Training       #
            ###########################################
            optimG.zero_grad()

            cpmap = generator(imgl)
            cpmapsmax = nn.Softmax2d()(cpmap)
            cpmaplsmax = nn.LogSoftmax()(cpmap)

            conff = nn.LogSoftmax()(discriminator(cpmapsmax))

            LGce = nn.NLLLoss2d()(cpmaplsmax, maskl)
            LGadv = nn.NLLLoss2d()(conff, targetr)

            LGadv_d = LGadv.data[0]
            LGce_d = LGce.data[0]

            LGadv = args.lam_adv * LGadv

            (LGce + LGadv).backward()
            #####################################
            # Use unlabelled data to get L_semi #
            #####################################
            LGsemi_d = 0
            if epoch > args.wait_semi:

                cpmap = generator(imgu)
                softpred = nn.Softmax2d()(cpmap)
                hardpred = torch.max(softpred, 1)[1].squeeze(1)
                conf = nn.Softmax2d()(discriminator(
                    Variable(softpred.data, volatile=True)))

                # One-hot mask, laid out (N, H, W, C) so that a boolean
                # (N, H, W) pixel mask can assign whole one-hot rows.
                n_classes = cpmap.size()[1]
                idx = np.zeros(tuple(cpmap.size()), dtype=np.uint8)
                idx = idx.transpose(0, 2, 3, 1)

                # Trust pixels where the *discriminator* believes the
                # prediction is real (channel 1 == targetr). This previously
                # thresholded the generator's own channel-1 output instead.
                confnp = conf[:, 1, ...].data.cpu().numpy()
                hardprednp = hardpred.data.cpu().numpy()
                trusted = confnp > args.t_semi
                idx[trusted] = np.identity(
                    n_classes, dtype=idx.dtype)[hardprednp[trusted]]

                if np.count_nonzero(idx) != 0:
                    cpmaplsmax = nn.LogSoftmax()(cpmap)
                    # Back to (N, C, H, W) so the mask lines up element-wise
                    # with cpmaplsmax in masked_select.
                    idx = torch.from_numpy(
                        np.ascontiguousarray(idx.transpose(0, 3, 1, 2))).byte()
                    if not args.nogpu:
                        idx = idx.cuda()
                    idx = Variable(idx)
                    LGsemi_arr = cpmaplsmax.masked_select(idx)
                    LGsemi = -1 * LGsemi_arr.mean()
                    LGsemi_d = LGsemi.data[0]
                    LGsemi = args.lam_semi * LGsemi
                    LGsemi.backward()
                else:
                    LGsemi_d = 0

                # Manually free memory held by the semi-supervised branch.
                del idx
                del trusted
                del conf
                del confnp
                del hardpred
                del softpred
                del hardprednp
                del cpmap
            LGseg_d = LGce_d + LGadv_d + LGsemi_d
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()

            print("[{}][{}] LD: {:.4f} LD_fake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f} LG_semi: {:.4f}"\
                    .format(epoch,itr,(LDr + LDf).data[0],LDr.data[0],LDf.data[0],LGseg_d,LGce_d,LGadv_d,LGsemi_d))

        snapshot(generator, valoader, epoch, best_miou, args.snapshot_dir,
                 args.prefix)
def train_adv(args):
    """Adversarially train the DeepLab-v2 generator against a discriminator.

    Every iteration first updates the discriminator ``Dis`` on the dataset's
    ``ohmask`` maps (target 1, "real") and on the generator's softmax output
    (target 0, "fake"). It then updates the generator with a pixel-wise NLL
    loss on ``mask`` plus ``args.lam_adv`` times an adversarial NLL loss that
    asks the discriminator to label the generator output as real.
    ``snapshot`` is called after every epoch.

    Args:
        args: parsed command-line namespace. Fields read here: ``no_norm``,
            ``dataset_dir``, ``batch_size``, ``val_orig``, ``init_net``,
            ``g_lr``, ``d_lr``, ``d_optim``, ``d_label_smooth``, ``lam_adv``,
            ``nogpu``, ``start_epoch``, ``max_epoch``, ``snapshot_dir`` and
            ``prefix``.
    """
    #######################
    # Training Dataloader #
    #######################
    imgtr = [ToTensor()] if args.no_norm else [ToTensor(), NormalizeOwn()]
    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop((321, 321))]

    trainset = PascalVOC(home_dir, args.dataset_dir,
                         img_transform=Compose(imgtr),
                         label_transform=Compose(labtr),
                         co_transform=Compose(cotr))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        # Zero-pad the images and apply no co-transform (no cropping).
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        imgtr = [ToTensor()] if args.no_norm else [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir, args.dataset_dir,
                       img_transform=Compose(imgtr),
                       label_transform=Compose(labtr),
                       co_transform=Compose(cotr),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = deeplabv2.ResDeeplab()
    # Load the pretrained weights selected by args.init_net, exactly as the
    # baseline trainer does; without this call the flag has no effect here.
    init_weights(generator, args.init_net)
    optimG = optim.SGD(filter(lambda p: p.requires_grad,
                              generator.parameters()),
                       lr=args.g_lr, momentum=0.9,
                       weight_decay=0.0001, nesterov=True)

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    # in_channels must match the channel count of both inputs fed to it
    # below: the generator's softmax map and ohmask.
    discriminator = Dis(in_channels=21)
    if args.d_optim == 'adam':
        optimD = optim.Adam(filter(lambda p: p.requires_grad,
                                   discriminator.parameters()),
                            lr=args.d_lr)
    else:
        optimD = optim.SGD(filter(lambda p: p.requires_grad,
                                  discriminator.parameters()),
                           lr=args.d_lr, weight_decay=0.0001,
                           momentum=0.5, nesterov=True)

    if not args.nogpu:
        discriminator = nn.DataParallel(discriminator).cuda()

    #############
    # TRAINING  #
    #############
    # NOTE(review): best_miou is never reassigned in this function, so every
    # snapshot() call receives -1; confirm snapshot() tracks the best score
    # itself.
    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        for batch_id, (img, mask, ohmask) in enumerate(trainloader):
            if args.nogpu:
                img, mask, ohmask = Variable(img), Variable(mask), Variable(
                    ohmask)
            else:
                img, mask, ohmask = Variable(img.cuda()), Variable(
                    mask.cuda()), Variable(ohmask.cuda())
            itr = len(trainloader) * (epoch - 1) + batch_id

            # Generator prediction used only as the discriminator's "fake"
            # input; volatile, so no graph back into the generator is kept.
            cpmap = generator(Variable(img.data, volatile=True))
            cpmap = nn.Softmax2d()(cpmap)

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]

            # Per-pixel fake (0) and real (1) targets for the discriminator.
            targetf = Variable(torch.zeros((N, H, W)).long(),
                               requires_grad=False)
            targetr = Variable(torch.ones((N, H, W)).long(),
                               requires_grad=False)
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            ##########################
            # DISCRIMINATOR TRAINING #
            ##########################
            optimD.zero_grad()

            # Train on Real: ohmask maps should be classified as 1.
            confr = nn.LogSoftmax()(discriminator(ohmask.float()))
            if args.d_label_smooth != 0:
                # Label smoothing: put weight d_label_smooth on the fake target.
                LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(confr,
                                                                 targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr, targetf)
            else:
                LDr = nn.NLLLoss2d()(confr, targetr)
            LDr.backward()

            # Train on Fake: re-wrapping .data detaches it from any graph.
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff, targetf)
            LDf.backward()

            poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            ######################
            # GENERATOR TRAINING #
            ######################
            optimG.zero_grad()

            cmap = generator(img)
            cpmapsmax = nn.Softmax2d()(cmap)
            cpmaplsmax = nn.LogSoftmax()(cmap)
            # Backprop through this also accumulates discriminator grads;
            # they are cleared by optimD.zero_grad() at the next D update.
            conff = nn.LogSoftmax()(discriminator(cpmapsmax))

            LGce = nn.NLLLoss2d()(cpmaplsmax, mask)
            # Adversarial term: reward the generator when D says "real".
            LGadv = nn.NLLLoss2d()(conff, targetr)
            LGseg = LGce + args.lam_adv * LGadv

            LGseg.backward()
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()

            # Fake and real D losses are passed in the order the labels name.
            print("[{}][{}] LD: {:.4f} LDfake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f}"
                  .format(epoch, itr, (LDr + LDf).data[0], LDf.data[0],
                          LDr.data[0], LGseg.data[0], LGce.data[0],
                          LGadv.data[0]))
        snapshot(generator, valoader, epoch, best_miou, args.snapshot_dir,
                 args.prefix)
def train_base(args):
    """Train the DeepLab-v2 segmentation network on its own (no discriminator).

    Builds PascalVOC train and validation loaders and initialises the network
    with ``init_weights(..., args.init_net)``. It then runs Nesterov SGD under
    ``poly_lr_scheduler`` on a pixel-wise NLL loss, calling ``snapshot`` after
    every epoch.

    Args:
        args: parsed command-line namespace. Fields read here: ``no_norm``,
            ``dataset_dir``, ``batch_size``, ``val_orig``, ``init_net``,
            ``g_lr``, ``nogpu``, ``start_epoch``, ``max_epoch``,
            ``snapshot_dir`` and ``prefix``.
    """

    def image_steps(pad):
        # Image-side transforms: optional zero padding, tensor conversion,
        # then normalisation unless args.no_norm is set.
        steps = [ZeroPadding()] if pad else []
        steps.append(ToTensor())
        if not args.no_norm:
            steps.append(NormalizeOwn())
        return steps

    # ---- training data: random 321x321 crops ----
    train_set = PascalVOC(
        home_dir, args.dataset_dir,
        img_transform=Compose(image_steps(pad=False)),
        label_transform=Compose([IgnoreLabelClass(), ToTensorLabel()]),
        co_transform=Compose([RandomSizedCrop((321, 321))]))
    train_loader = DataLoader(train_set, batch_size=args.batch_size,
                              shuffle=True, num_workers=2, drop_last=True)

    # ---- validation data ----
    # With val_orig the images are zero-padded and left uncropped; otherwise
    # they get the same random 321x321 crop used for training.
    val_set = PascalVOC(
        home_dir, args.dataset_dir,
        img_transform=Compose(image_steps(pad=args.val_orig)),
        label_transform=Compose([IgnoreLabelClass(), ToTensorLabel()]),
        co_transform=Compose(
            [] if args.val_orig else [RandomSizedCrop((321, 321))]),
        train_phase=False)
    val_loader = DataLoader(val_set, batch_size=1)

    # ---- network and optimiser ----
    net = deeplabv2.ResDeeplab()
    init_weights(net, args.init_net)

    trainable = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable, lr=args.g_lr, momentum=0.9,
                          weight_decay=0.0001, nesterov=True)

    if not args.nogpu:
        net = nn.DataParallel(net).cuda()

    # ---- training loop ----
    best_miou = -1
    batches_per_epoch = len(train_loader)
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        net.train()
        for step, (img, mask, _) in enumerate(train_loader):
            if not args.nogpu:
                img, mask = img.cuda(), mask.cuda()
            img, mask = Variable(img), Variable(mask)

            # Global iteration index drives the polynomial LR decay.
            global_step = batches_per_epoch * (epoch - 1) + step
            scores = net(img)
            log_probs = nn.LogSoftmax()(scores)
            loss = nn.NLLLoss2d()(log_probs, mask)

            poly_lr_scheduler(optimizer, args.g_lr, global_step)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("[{}][{}]Loss: {:0.4f}".format(epoch, global_step,
                                                 loss.data[0]))

        snapshot(net, val_loader, epoch, best_miou, args.snapshot_dir,
                 args.prefix)