def test():
    """Visual smoke test for the labeled Corrosion training pipeline.

    Builds the labeled training set with a 512x512 ``RandomSizedCrop3`` and
    float label tensors, then shows the image, the mask and the fourth batch
    element (``emask``) side by side for every sample.  ``plt.show()`` blocks
    on each sample, so this is meant for interactive use only.

    NOTE(review): relies on ``torch``, ``Corrosion``, ``DataLoader``,
    ``home_dir`` and ``args`` being available at module level -- confirm.
    """
    from utils.transforms import IgnoreLabelClass, ToTensorLabel, NormalizeOwn, RandomSizedCrop3
    from torchvision.transforms import ToTensor, Compose
    import matplotlib.pyplot as plt

    imgtr = [ToTensor(), NormalizeOwn()]
    # sigmoid variant: labels are converted to float tensors
    labtr = [IgnoreLabelClass(), ToTensorLabel(tensor_type=torch.FloatTensor)]
    cotr = [RandomSizedCrop3((512, 512))]

    dataset_dir = '/media/data/seg_dataset'
    trainset = Corrosion(home_dir,
                         dataset_dir,
                         img_transform=Compose(imgtr),
                         label_transform=Compose(labtr),
                         co_transform=Compose(cotr),
                         split=args.split,
                         labeled=True)
    # Fixed: the loader used to be built from the undefined name `trainset_l`.
    trainloader = DataLoader(trainset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    for img, mask, _, emask in trainloader:
        # batch_size is 1, so index 0 is the only sample of the batch
        img, mask, emask = img.numpy()[0], mask.numpy()[0], emask.numpy()[0]
        # ToTensor yields channel-first images; imshow needs channel-last.
        if img.ndim == 3 and img.shape[0] == 3:
            img = img.transpose(1, 2, 0)
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        ax1.imshow(img)
        ax2.imshow(mask)
        ax3.imshow(emask)
        plt.show()
Example #2
0
def main():
    """Entry point for training the Corrosion segmentation model.

    Seeds the RNGs, builds the labeled (and, in ``semi`` mode, unlabeled)
    training loaders plus the validation loader, creates the attention U-Net
    generator and, for every mode except ``base``, the discriminator, and
    finally hands everything to the training routine chosen by ``args.mode``.
    """
    args = parse_args()
    use_gpu = not args.nogpu

    # Fixed seeds for reproducibility
    random.seed(0)
    torch.manual_seed(0)
    if use_gpu:
        torch.cuda.manual_seed_all(0)

    def image_steps(pad):
        """Image transform list: optional zero padding, tensor conversion, optional normalisation."""
        steps = [ZeroPadding()] if pad else []
        steps.append(ToTensor())
        if not args.no_norm:
            steps.append(NormalizeOwn())
        return steps

    def trainable(model):
        """Iterator over the parameters of `model` that require gradients."""
        return filter(lambda p: p.requires_grad, model.parameters())

    ########################
    # Training Dataloaders #
    ########################
    train_img = image_steps(pad=False)
    # softmax: integer class labels
    train_lab = [IgnoreLabelClass(), ToTensorLabel()]
    train_co = [RandomSizedCrop3((320, 320))]

    print("dataset_dir: ", args.dataset_dir)
    trainset_l = Corrosion(home_dir,
                           args.dataset_dir,
                           img_transform=Compose(train_img),
                           label_transform=Compose(train_lab),
                           co_transform=Compose(train_co),
                           split=args.split,
                           labeled=True)
    trainloader_l = DataLoader(trainset_l,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=2,
                               drop_last=True)

    if args.mode == 'semi':
        # The unlabeled split is only needed for semi-supervised training
        trainset_u = Corrosion(home_dir,
                               args.dataset_dir,
                               img_transform=Compose(train_img),
                               label_transform=Compose(train_lab),
                               co_transform=Compose(train_co),
                               split=args.split,
                               labeled=False)
        trainloader_u = DataLoader(trainset_u,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=2,
                                   drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    # --val_orig: zero-pad and skip the joint crop; otherwise crop as in training
    val_img = image_steps(pad=args.val_orig)
    val_lab = [IgnoreLabelClass(), ToTensorLabel()]
    val_co = [] if args.val_orig else [RandomSizedCrop3((320, 320))]

    valset = Corrosion(home_dir,
                       args.dataset_dir,
                       img_transform=Compose(val_img),
                       label_transform=Compose(val_lab),
                       co_transform=Compose(val_co),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    # softmax generator: in_chs=3, out_chs=2
    generator = unet.AttU_Net()
    init_weights(generator, args.init_net)

    if args.init_net == 'unet':
        optimG = optim.Adam(trainable(generator), args.g_lr, [0.5, 0.999])
    else:
        optimG = optim.SGD(trainable(generator),
                           lr=args.g_lr,
                           momentum=0.9,
                           weight_decay=0.0001,
                           nesterov=True)

    if use_gpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    if args.mode != "base":
        # Consumes the two-channel softmax output of the generator
        discriminator = Dis(in_channels=2)
        if args.d_optim == 'adam':
            optimD = optim.Adam(trainable(discriminator),
                                lr=args.d_lr,
                                weight_decay=0.0001)
        else:
            optimD = optim.SGD(trainable(discriminator),
                               lr=args.d_lr,
                               weight_decay=0.0001,
                               momentum=0.5,
                               nesterov=True)

        if use_gpu:
            discriminator = nn.DataParallel(discriminator).cuda()

    ############
    # TRAINING #
    ############
    mode = args.mode
    if mode == 'base':
        train_base(generator, optimG, trainloader_l, valoader, args)
    elif mode == 'adv':
        train_adv(generator, discriminator, optimG, optimD,
                  trainloader_l, valoader, args)
    elif mode == 'semi':
        train_semi(generator, discriminator, optimG, optimD,
                   trainloader_l, trainloader_u, valoader, args)
    else:
        print("training mode incorrect")
Example #3
0
def validate_sim(net,data_dir,train_s_idx,train_e_idx,val_s_idx,val_e_idx,gpu):
    """Score `net` on the simulated dataset by propagating training labels.

    Every validation slice is paired with every training slice;
    ``transform_vols`` is called for each pair and its second return value is
    accumulated per validation slice.  The arg-max over the class axis of the
    accumulated votes is compared with the ground truth via ``dice_score``.

    Args:
        net: network forwarded to ``transform_vols``.
        data_dir: directory with ``img/<idx>.tif`` and ``cls/<idx>.png`` files.
        train_s_idx, train_e_idx: inclusive index range of the training slices.
        val_s_idx, val_e_idx: inclusive index range of the validation slices.
        gpu: when truthy, tensors are moved to the GPU before ``transform_vols``.

    Returns:
        ``np.average(score, axis=0)`` of the value returned by ``dice_score``.

    Raises:
        ValueError: if the validation index range is empty.

    NOTE(review): relies on the module-level ``nclasses``; ``out_cls_oh`` is
    presumably a one-hot prediction living on the CPU -- confirm against
    ``transform_vols``.
    """
    global GPU
    GPU = gpu

    def load_pairs(s_idx, e_idx):
        """Load the (image, label) tensors for the inclusive index range."""
        imgs, labels = [], []
        for idx in np.arange(s_idx, e_idx + 1):
            img = Image.open(os.path.join(data_dir, "img", str(idx) + '.tif'))
            cls = Image.open(os.path.join(data_dir, "cls", str(idx) + '.png')).convert('P')
            imgs.append(ToTensorTIF()(img))
            labels.append(ToTensorLabel()(cls))
        return imgs, labels

    def with_channel_dim(tensor):
        """Add a leading dimension to 2-D tensors and insist on a 3-D result."""
        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(0)
        assert(tensor.dim() == 3)
        return tensor

    n_val = val_e_idx - val_s_idx + 1
    if n_val <= 0:
        raise ValueError("empty validation range: {}..{}".format(val_s_idx, val_e_idx))

    val_img_arr, val_cls_arr = load_pairs(val_s_idx, val_e_idx)
    train_img_arr, train_cls_arr = load_pairs(train_s_idx, train_e_idx)

    predictions = None
    gt = None
    for v_idx, (val_img, val_cls) in enumerate(zip(val_img_arr, val_cls_arr)):
        val_img = with_channel_dim(val_img)
        val_cls = with_channel_dim(val_cls)

        if predictions is None:
            # Allocated lazily because the slice size is only known once a sample is loaded
            predictions = torch.zeros((n_val, nclasses,) + val_img[0].size()).byte()
            gt = torch.zeros((n_val,) + val_img[0].size()).byte()
        # Record the ground truth while the label is still on the CPU
        gt[v_idx] = val_cls[0]

        if gpu:
            val_img = val_img.cuda()
            # Fixed: the result used to be bound to the misspelled `val_Cls`,
            # so the validation label stayed on the CPU.
            val_cls = val_cls.cuda()

        # Labels are propagated from every training example.  The unused draw of
        # 50 random training indices (deprecated np.random.random_integers) is gone.
        for train_img, train_cls in zip(train_img_arr, train_cls_arr):
            train_img = with_channel_dim(train_img)
            train_cls = with_channel_dim(train_cls)
            if gpu:
                train_img = train_img.cuda()
                train_cls = train_cls.cuda()

            out_img,out_cls_oh = transform_vols(train_img,val_img,train_cls,val_cls,net)
            predictions[v_idx] += out_cls_oh

    # Majority vote over the class axis
    _,predictions = torch.max(predictions,dim=1)
    score = dice_score(gt.numpy(),predictions.numpy(),nclasses)
    score = np.average(score,axis=0)
    return score
Example #4
0
def main():
    """Evaluate a saved ResDeeplab generator on the Corrosion test data.

    Parses the command line, builds the test loader, restores the generator
    weights from ``snapshot``, predicts every test image and prints the mean
    IoU returned by ``scores``.

    Raises:
        FileNotFoundError: if the snapshot file does not exist.
    """
    home_dir = os.path.dirname(os.path.realpath(__file__))

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_dir",
                        help="A directory containing img (Images) \
                        and cls (GT Segmentation) folder")
    parser.add_argument("snapshot", help="Snapshot with the saved model")
    parser.add_argument("--val_orig",
                        help="Do Inference on original size image.\
                        Otherwise, crop to 321x321 like in training ",
                        action='store_true')
    parser.add_argument("--norm",help="Normalize the test images",\
                        action='store_true')
    args = parser.parse_args()

    label_transform = transforms.Compose(
        [IgnoreLabelClass(), ToTensorLabel()])
    # NOTE(review): the branches look swapped with respect to the --val_orig
    # help text (the flag selects the 321x321 crop, the default zero-pads the
    # full image).  Behaviour is kept as-is; confirm which one is intended.
    if args.val_orig:
        img_steps = [ToTensor()]
        if args.norm:
            img_steps.append(NormalizeOwn(dataset='corrosion'))
        co_transform = transforms.Compose([RandomSizedCrop((321, 321))])

        testset = Corrosion(home_dir, args.dataset_dir,
                            img_transform=transforms.Compose(img_steps),
                            label_transform=label_transform,
                            co_transform=co_transform,
                            train_phase=False)
    else:
        img_steps = [ZeroPadding(), ToTensor()]
        if args.norm:
            img_steps.append(NormalizeOwn(dataset='corrosion'))

        testset = Corrosion(home_dir, args.dataset_dir,
                            img_transform=transforms.Compose(img_steps),
                            label_transform=label_transform,
                            train_phase=False)
    testloader = DataLoader(testset, batch_size=1)

    generator = deeplabv2.ResDeeplab()
    # Explicit check instead of `assert`, which is stripped under `python -O`
    if not os.path.isfile(args.snapshot):
        raise FileNotFoundError(args.snapshot)
    snapshot = torch.load(args.snapshot)

    # The checkpoint was saved from an nn.DataParallel wrapper, so its keys
    # carry a 'module.' prefix.  Keys without the prefix are kept unchanged
    # (they used to collapse into the empty string).
    prefix = 'module.'
    saved_net = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in snapshot['state_dict'].items()
    }
    print('Snapshot Loaded')
    generator.load_state_dict(saved_net)
    generator.eval()
    generator = nn.DataParallel(generator).cuda()
    print('Generator Loaded')
    n_classes = 2

    gts, preds = [], []

    print('Prediction Goint to Start')

    # TODO: Crop out the padding before prediction
    for img_id, (img, gt_mask, _) in enumerate(testloader):
        print("Generating Predictions for Image {}".format(img_id))
        gt_mask = gt_mask.numpy()[0]
        img = Variable(img.cuda())
        out_pred_map = generator(img)

        # Get hard prediction: crop the score map to the ground-truth extent,
        # then take the arg-max over the class axis
        soft_pred = out_pred_map.data.cpu().numpy()[0]
        soft_pred = soft_pred[:, :gt_mask.shape[0], :gt_mask.shape[1]]
        hard_pred = np.argmax(soft_pred, axis=0).astype(np.uint8)

        # Collected row by row, as expected by the `scores` call below
        for gt_, pred_ in zip(gt_mask, hard_pred):
            gts.append(gt_)
            preds.append(pred_)
    score, class_iou = scores(gts, preds, n_class=n_classes)

    print("Mean IoU: {}".format(score))
Example #5
0
def main():
    """Train the displacement-predicting network and validate it every epoch.

    Workflow: parse arguments, connect to Visdom, build the training loader
    for the selected dataset, create the network / optimizer / LR scheduler,
    optionally resume from a snapshot, then for every epoch run the training
    loop (similarity + segmentation + smoothness losses), validate, take a
    snapshot and push losses (and, for MRBrainS, dice scores and displacement
    heatmaps) to Visdom.

    Raises:
        ValueError: for an unsupported ``args.dataset`` or ``args.similarity``.

    NOTE(review): relies on module-level names not visible here (``H``, ``W``,
    ``nclasses``, ``homedir``, ``perm``, ``weight_init``, ``cc``,
    ``process_disp``, ``free_vars``, ``take_snapshot``, ``snapshot_path``) and
    on the pre-0.4 PyTorch API (``Variable``, ``loss.data[0]``) -- confirm.
    """
    args = parse_arguments()
    args_str = str(vars(args))
    global w_init
    w_init = args.weight_init

    # Setup Visdom Logger
    vis = visdom.Visdom(server=args.visdom_server,
                        port=int(args.visdom_port),
                        env=args.exp_name)
    vis.close(
        win=None)  # Close all existing windows from the current environment

    vis.text(args_str)

    #############################################
    # TRAINING DATASET: GENERIC TRANSFORMATION ##
    #############################################
    img_transform = [ToTensorTIF()]
    label_transform = [ToTensorLabel()]
    if args.dataset == "sim":
        trainset = Sim(args.data_dir_2d,
                       args.train_s_idx,
                       args.train_e_idx,
                       co_transform=Compose([]),
                       img_transform=Compose(img_transform),
                       label_transform=Compose(label_transform))
    elif args.dataset == "ibsr":
        trainset = IBSRv1(homedir,
                          args.data_dir_2d,
                          co_transform=Compose([]),
                          img_transform=Compose(img_transform),
                          label_transform=Compose(label_transform))
    elif args.dataset == "mrbrains":
        trainset = MRBrainS(homedir,
                            args.data_dir_2d,
                            co_transform=Compose([]),
                            img_transform=Compose(img_transform),
                            label_transform=Compose(label_transform))
    else:
        # Fail early instead of hitting a NameError on `trainset` below
        raise ValueError("Unsupported dataset: {}".format(args.dataset))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True)
    print("Dataset Loaded")

    #####################
    # PARAMETER NETWORK #
    #####################
    net = UNetSmall(args.nker)
    if torch.cuda.is_available():
        net = nn.DataParallel(net).cuda()
    print("Network Loaded")
    net.apply(weight_init)

    #############
    # OPTIMIZER #
    #############
    opt = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),
                     lr=args.lr,
                     weight_decay=args.l2_weight)

    ##################################################################
    ### GENERATE A BASE GRID USING IDENTITY AFFINE TRANSFORMATION ####
    ##################################################################
    theta = torch.FloatTensor([1, 0, 0, 0, 1, 0])
    theta = theta.view(2, 3)
    theta = theta.expand(args.batch_size, 2, 3)
    if torch.cuda.is_available():
        theta = theta.cuda()

    theta = Variable(theta)
    # Identity sampling grids; the predicted displacement is added to these
    basegrid_img = F.affine_grid(theta, torch.Size((args.batch_size, 1, H, W)))
    basegrid_label = F.affine_grid(
        theta, torch.Size((args.batch_size, nclasses, H, W)))

    best_score = [0, 0, 0, 0]

    ##############################################
    # Resume training if args.resume is not None #
    ##############################################
    scheduler = StepLR(opt,
                       step_size=args.step_lr_step_size,
                       gamma=args.step_lr_gamma)
    if args.resume is not None:
        print("Resuming Training from {}".format(args.resume))
        snapshot = torch.load(args.resume)
        args.start_epoch = snapshot['epoch']
        best_score = snapshot['best_score']
        net.load_state_dict(snapshot['net'])
        opt.load_state_dict(snapshot['optimizer'])
    else:
        print("No Checkpoint Found")

    #####################
    # VISDOM LOSS SETUP #
    #####################
    win_loss_total = vis.line(Y=np.empty(1), opts=dict(title='loss_total'))
    win_loss_sim = vis.line(Y=np.empty(1), opts=dict(title='loss_sim'))
    win_loss_reg = vis.line(Y=np.empty(1), opts=dict(title='loss_reg'))
    win_loss_seg = vis.line(Y=np.empty(1), opts=dict(title='loss seg'))

    #####################
    # VISDOM DICE SETUP #
    #####################
    # The dice windows are created lazily with the first available scores
    win_dice_t1 = None
    win_dice_t1_ir = None
    win_dice_t2 = None

    #############
    ## TRAINING #
    #############
    for epoch in np.arange(args.start_epoch, args.max_epoch):
        scheduler.step()

        loss_total = []
        loss_sim = []
        loss_reg = []
        loss_seg = []
        steps = []

        for batch_id, ((img1, label1, ohlabel1, fname1),
                       (img2, label2, ohlabel2,
                        fname2)) in enumerate(trainloader):
            if img1 is None or label1 is None or img2 is None or label2 is None or ohlabel1 is None:
                continue
            net.train()
            itr = len(trainloader) * (epoch) + batch_id
            steps.append(itr)

            # The network input is the image pair concatenated along dim 1
            combimg = torch.cat((img1, img2), 1)
            if torch.cuda.is_available():
                img1, label1 = Variable(img1.cuda()), Variable(label1.cuda())
                img2, label2 = Variable(img2.cuda()), Variable(label2.cuda())
                combimg = Variable(combimg.cuda())
                ohlabel1 = Variable(ohlabel1.cuda())
            else:
                img1, label1 = Variable(img1), Variable(label1)
                img2, label2 = Variable(img2), Variable(label2)
                combimg = Variable(combimg)
                ohlabel1 = Variable(ohlabel1)

            # Single-scale displacement (a multi-scale disp0/disp1/disp2
            # variant existed here as commented-out code)
            disp = net(combimg)
            disp = process_disp(disp)

            ##########################
            ## IMAGE TRANSFORMATION ##
            ##########################
            grid_img = basegrid_img + disp
            img1t = F.grid_sample(img1, grid_img)

            if args.similarity == 'cc':
                Lsim = cc(img1t, img2)
            elif args.similarity == 'l2':
                Lsim = nn.MSELoss()(img1t, img2)
            else:
                # Previously surfaced as a NameError on `Lsim` further down
                raise ValueError(
                    "Unsupported similarity: {}".format(args.similarity))

            ###########################
            ### LABEL TRANSFORMATION ##
            ###########################
            Lseg = Variable(torch.Tensor([0]), requires_grad=True)
            if torch.cuda.is_available():
                Lseg = Variable(torch.Tensor([0]).cuda(), requires_grad=True)
            if args.lambdaseg != 0:
                # Warp the one-hot label of image 1 and compare it with label 2
                grid_label = basegrid_label + disp
                cprob2 = F.grid_sample(ohlabel1.float(), grid_label)
                logcprob2 = nn.LogSoftmax()(cprob2)
                Lseg = nn.NLLLoss()(logcprob2, label2)

            ###################
            ## REGULARIZATON ##
            ###################
            Lreg = Variable(torch.Tensor([0]), requires_grad=True)
            target = torch.zeros(1)
            if torch.cuda.is_available():
                Lreg = Variable(torch.Tensor([0]).cuda(), requires_grad=True)
            if args.lambdareg != 0:
                # First-order finite differences along dims 1 and 2 of `disp`
                dx = disp[:, 1:, :, :] - disp[:, :-1, :, :]
                dy = disp[:, :, 1:, :] - disp[:, :, :-1, :]
                # L1 penalty on the second-order differences
                d2dx2 = torch.abs(dx[:, 1:, :, :] - dx[:, :-1, :, :])
                d2dy2 = torch.abs(dy[:, :, 1:, :] - dy[:, :, :-1, :])
                d2dxdy = torch.abs(dx[:, :, 1:, :] - dx[:, :, :-1, :])
                d2dydx = torch.abs(dy[:, 1:, :, :] - dy[:, :-1, :, :])

                d2_mean = (torch.mean(d2dx2) + torch.mean(d2dy2) +
                           torch.mean(d2dxdy) + torch.mean(d2dydx)) / 4
                Lreg = d2_mean

            ######################
            ## PARAMETER UPDATE ##
            ######################
            Ltotal = Lsim + args.lambdareg * Lreg + args.lambdaseg * Lseg
            opt.zero_grad()
            Ltotal.backward()
            opt.step()

            free_vars(img1, img2, label1, label2, combimg, ohlabel1, ohlabel2,
                      target)

            loss_total.append(Ltotal.data[0])
            loss_reg.append(Lreg.data[0])
            loss_sim.append(Lsim.data[0])
            loss_seg.append(Lseg.data[0])

            if itr % args.error_iter == 0:
                print("[{}][{}] Ltotal: {:.6} Lsim: {:.6f} Lseg: {:.6f} Lreg: {:.6f} ".\
                    format(epoch,itr,Ltotal.data[0],Lsim.data[0],args.lambdaseg*(Lseg.data[0]),Lreg.data[0]))

        ############
        # VALIDATE #
        ############
        # Only the MRBrainS validation returns per-modality results
        val_results = None
        if args.dataset == "sim":
            score = validate_sim(net, args.data_dir_2d, args.train_s_idx,
                                 args.train_e_idx, args.val_s_idx,
                                 args.val_e_idx, args.gpu)
        elif args.dataset == "ibsr":
            score = validate_ibsr(net, args.data_dir_3d, args.train_vols,
                                  args.val_vols, args.img_suffix,
                                  args.cls_suffix, perm, args.gpu)
        else:
            val_results = validate_mrbrains(
                net, args.data_dir_3d, [1, 2, 3],
                [4])  # Makeshift changes! Be very careful
            dice_t1 = val_results['scores']['t1']
            dice_t1_ir = val_results['scores']['t1_ir']
            dice_t2 = val_results['scores']['t2']
            if torch.cuda.is_available():
                dice_t1 = dice_t1.cpu().numpy()
                dice_t1_ir = dice_t1_ir.cpu().numpy()
                dice_t2 = dice_t2.cpu().numpy()
            else:
                dice_t1 = dice_t1.numpy()
                dice_t1_ir = dice_t1_ir.numpy()
                dice_t2 = dice_t2.numpy()
            print("dice_t1: {}".format(dice_t1))
            print("dice_t1_ir: {}".format(dice_t1_ir))
            print("dice_t2: {}".format(dice_t2))
            score = (dice_t1 + dice_t1_ir + dice_t2) / 3

        ############
        # SNAPSHOT #
        ############
        best_score = take_snapshot(
            best_score, score, opt, net, epoch,
            snapshot_path(args.snapshot_dir, args.exp_name))

        ##################
        # VISUALIZE LOSS #
        ##################
        vis.line(Y=np.array(loss_total),
                 X=np.array(steps),
                 win=win_loss_total,
                 update='append')
        vis.line(Y=np.array(loss_sim),
                 X=np.array(steps),
                 win=win_loss_sim,
                 update='append')
        vis.line(Y=np.array(loss_reg),
                 X=np.array(steps),
                 win=win_loss_reg,
                 update='append')
        vis.line(Y=np.array(loss_seg),
                 X=np.array(steps),
                 win=win_loss_seg,
                 update='append')

        # Fixed: the dice/field plots read variables that only exist after the
        # MRBrainS validation, which raised NameError for "sim" and "ibsr".
        if val_results is not None:
            ########################
            # VISUALIZE THE SCORES #
            ########################
            if win_dice_t1 is None:
                win_dice_t1 = vis.line(Y=dice_t1.reshape(1, nclasses),
                                       X=np.array([epoch]),
                                       opts=dict(title='dice_t1'))
            else:
                vis.line(Y=dice_t1.reshape(1, nclasses),
                         X=np.array([epoch]),
                         win=win_dice_t1,
                         update='append')
            if win_dice_t1_ir is None:
                win_dice_t1_ir = vis.line(Y=dice_t1_ir.reshape(1, nclasses),
                                          X=np.array([epoch]),
                                          opts=dict(title='dice_t1_ir'))
            else:
                vis.line(Y=dice_t1_ir.reshape(1, nclasses),
                         X=np.array([epoch]),
                         win=win_dice_t1_ir,
                         update='append')
            if win_dice_t2 is None:
                win_dice_t2 = vis.line(Y=dice_t2.reshape(1, nclasses),
                                       X=np.array([epoch]),
                                       opts=dict(title='dice_t2'))
            else:
                vis.line(Y=dice_t2.reshape(1, nclasses),
                         X=np.array([epoch]),
                         win=win_dice_t2,
                         update='append')

            #####################################################################################
            # VISUALIZE THE TRANFORMATION PARAMETER OF CENTRAL SLICE AS HEATMAP FOR TRAIN VOL 0 #
            #####################################################################################
            field_t1 = val_results['fields']['t1']
            field_t1_ir = val_results['fields']['t1_ir']
            field_t2 = val_results['fields']['t2']
            if torch.cuda.is_available():
                field_t1 = field_t1.cpu()
                field_t1_ir = field_t1_ir.cpu()
                field_t2 = field_t2.cpu()

            # Index 100 on the third axis is the hard-coded slice; the last
            # axis holds the two displacement components
            transform_x_t1 = field_t1[0, 0, 100, :, :, 0]
            transform_y_t1 = field_t1[0, 0, 100, :, :, 1]
            transform_x_t1_ir = field_t1_ir[0, 0, 100, :, :, 0]
            transform_y_t1_ir = field_t1_ir[0, 0, 100, :, :, 1]
            transform_x_t2 = field_t2[0, 0, 100, :, :, 0]
            transform_y_t2 = field_t2[0, 0, 100, :, :, 1]

            vis.heatmap(X=transform_x_t1,
                        opts=dict(title="t1_0_x_{}".format(epoch)))
            vis.heatmap(X=transform_y_t1,
                        opts=dict(title="t1_0_y_{}".format(epoch)))
            vis.heatmap(X=transform_x_t1_ir,
                        opts=dict(title="t1_ir_0_x_{}".format(epoch)))
            vis.heatmap(X=transform_y_t1_ir,
                        opts=dict(title="t1_ir_0_y_{}".format(epoch)))
            vis.heatmap(X=transform_x_t2,
                        opts=dict(title="t2_0_x_{}".format(epoch)))
            vis.heatmap(X=transform_y_t2,
                        opts=dict(title="t2_0_y_{}".format(epoch)))

        vis.save([args.exp_name])
Example #6
0
def evaluate_discriminator():
    """Run a saved generator/discriminator pair over the Corrosion test data.

    For every test image the generator prediction is saved under ``results/``
    as a colourised mask (``<name>.png``) and as an overlay on the input
    image (``<name>_vis.png``); the arg-max of the discriminator output is
    saved as ``<name>_dis.png``.  The mean IoU of the generator predictions,
    as returned by ``scores``, is printed at the end.

    Raises:
        FileNotFoundError: if a snapshot file or a source image is missing.
    """
    home_dir = os.path.dirname(os.path.realpath(__file__))

    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_dir",
                        help="A directory containing img (Images) \
                        and cls (GT Segmentation) folder")
    parser.add_argument("snapshot_g",
                        help="Snapshot with the saved generator model")
    parser.add_argument("snapshot_d",
                        help="Snapshot with the saved discriminator model")
    parser.add_argument("--val_orig",
                        help="Do Inference on original size image.\
                        Otherwise, crop to 320x320 like in training ",
                        action='store_true')
    parser.add_argument("--norm",help="Normalize the test images",\
                        action='store_true')
    args = parser.parse_args()

    label_transform = transforms.Compose(
        [IgnoreLabelClass(), ToTensorLabel()])
    # NOTE(review): the branches look swapped with respect to the --val_orig
    # help text (the flag selects the 320x320 resize, the default zero-pads).
    # The overlay below resizes the source image to 320x320, so it presumably
    # only lines up with the --val_orig path.  Behaviour kept; please confirm.
    if args.val_orig:
        img_steps = [ToTensor()]
        if args.norm:
            img_steps.append(NormalizeOwn(dataset='corrosion'))
        co_transform = transforms.Compose([ResizedImage3((320, 320))])

        testset = Corrosion(home_dir, args.dataset_dir,
                            img_transform=transforms.Compose(img_steps),
                            label_transform=label_transform,
                            co_transform=co_transform,
                            train_phase=False)
    else:
        img_steps = [ZeroPadding(), ToTensor()]
        if args.norm:
            img_steps.append(NormalizeOwn(dataset='corrosion'))

        testset = Corrosion(home_dir, args.dataset_dir,
                            img_transform=transforms.Compose(img_steps),
                            label_transform=label_transform,
                            train_phase=False)
    testloader = DataLoader(testset, batch_size=1)

    def load_weights(path):
        """Load a checkpoint and drop the nn.DataParallel 'module.' key prefix."""
        # Explicit check instead of `assert`, which is stripped under `python -O`
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        prefix = 'module.'
        # Keys without the prefix are kept (they used to collapse into '')
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in torch.load(path)['state_dict'].items()
        }

    generator = unet.AttU_Net()
    print(args.snapshot_g)
    saved_net = load_weights(args.snapshot_g)

    discriminator = Dis(in_channels=2)
    print(args.snapshot_d)
    saved_net_d = load_weights(args.snapshot_d)

    print('Generator Snapshot Loaded')
    generator.load_state_dict(saved_net)
    generator.eval()
    generator = nn.DataParallel(generator).cuda()
    print('Generator Loaded')

    print('Discriminator Snapshot Loaded')
    discriminator.load_state_dict(saved_net_d)
    discriminator.eval()
    discriminator = nn.DataParallel(discriminator).cuda()
    print('discriminator Loaded')
    n_classes = 2

    gts, preds = [], []
    print('Prediction Goint to Start')
    colorize = VOCColorize()
    palette = make_palette(2)
    IMG_DIR = osp.join(args.dataset_dir, 'corrosion/JPEGImages')
    # The output directory is not guaranteed to exist
    os.makedirs('results', exist_ok=True)

    # TODO: Crop out the padding before prediction
    for img_id, (img, gt_mask, _, gte_mask, name) in enumerate(testloader):
        print("Generating Predictions for Image {}".format(img_id))
        gt_mask = gt_mask.numpy()[0]
        img = Variable(img.cuda())

        # Reload the untransformed image from disk for the overlay
        img_path = osp.join(IMG_DIR, name[0] + '.jpg')
        print(img_path)
        img_array = cv2.imread(img_path)
        if img_array is None:
            # cv2.imread signals failure by returning None
            raise FileNotFoundError(img_path)
        img_array = cv2.resize(img_array, (320, 320),
                               interpolation=cv2.INTER_AREA)
        # Swaps the first and last colour channel of the loaded image
        img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        out_pred_map = generator(img)

        # Get hard prediction: crop to the ground-truth extent, then arg-max
        # over the class axis
        soft_pred = out_pred_map.data.cpu().numpy()[0]
        soft_pred = soft_pred[:, :gt_mask.shape[0], :gt_mask.shape[1]]
        hard_pred = np.argmax(soft_pred, axis=0).astype(np.uint8)

        # Get discriminator prediction
        dis_conf = discriminator(out_pred_map)
        dis_confsmax = nn.Softmax2d()(dis_conf)
        dis_soft_pred = dis_confsmax.data.cpu().numpy()[0]
        dis_hard_pred = np.argmax(dis_soft_pred, axis=0).astype(np.uint8)

        # `np.int` was removed in NumPy 1.24; the builtin `int` is equivalent
        output = np.asarray(hard_pred, dtype=int)
        filename = os.path.join('results', '{}.png'.format(name[0]))
        color_file = Image.fromarray(
            colorize(output).transpose(1, 2, 0), 'RGB')
        color_file.save(filename)

        masked_im = Image.fromarray(vis_seg(img_array, output, palette))
        masked_im.save(filename[0:-4] + '_vis.png')

        # discriminator output
        dis_output = np.asarray(dis_hard_pred, dtype=int)
        # Fixed: `name[0]` has no extension, so slicing off four characters
        # truncated the sample name and produced wrong/colliding file names.
        dis_filename = os.path.join('results',
                                    '{}_dis.png'.format(name[0]))
        dis_color_file = Image.fromarray(
            colorize(dis_output).transpose(1, 2, 0), 'RGB')
        dis_color_file.save(dis_filename)

        # Collected row by row, as expected by the `scores` call below
        for gt_, pred_ in zip(gt_mask, hard_pred):
            gts.append(gt_)
            preds.append(pred_)
    score, class_iou = scores(gts, preds, n_class=n_classes)
    print("Mean IoU: {}".format(score))
def main():
    """Build the PascalVOC pipelines and networks, then run the selected training mode.

    Reads its configuration from ``parse_args()``. ``args.mode == 'base'``
    trains the generator alone, ``'adv'`` trains generator and discriminator
    on labelled data, and every other value falls through to semi-supervised
    training (which additionally consumes the unlabelled loader).
    """
    args = parse_args()
    use_gpu = not args.nogpu

    # Seed every RNG in use so that runs are repeatable.
    random.seed(0)
    torch.manual_seed(0)
    if use_gpu:
        torch.cuda.manual_seed_all(0)

    ########################
    # Training Dataloaders #
    ########################
    train_imgtr = [ToTensor()]
    if not args.no_norm:
        train_imgtr.append(NormalizeOwn())
    train_labtr = [IgnoreLabelClass(), ToTensorLabel()]
    train_cotr = [RandomSizedCrop((321, 321))]

    def make_trainset(labeled):
        # Labelled and unlabelled subsets share the same transform objects
        # and differ only in the ``labeled`` flag.
        return PascalVOC(home_dir,
                         args.dataset_dir,
                         img_transform=Compose(train_imgtr),
                         label_transform=Compose(train_labtr),
                         co_transform=Compose(train_cotr),
                         split=args.split,
                         labeled=labeled)

    trainset_l = make_trainset(True)
    trainset_u = make_trainset(False)

    loader_kwargs = dict(batch_size=args.batch_size,
                         shuffle=True,
                         num_workers=2,
                         drop_last=True)
    trainloader_l = DataLoader(trainset_l, **loader_kwargs)
    trainloader_u = DataLoader(trainset_u, **loader_kwargs)

    #########################
    # Validation Dataloader #
    #########################
    # With ``val_orig`` the image is zero-padded and no joint crop is applied;
    # otherwise validation uses the same random 321x321 crop as training.
    val_imgtr = [ZeroPadding(), ToTensor()] if args.val_orig else [ToTensor()]
    if not args.no_norm:
        val_imgtr.append(NormalizeOwn())
    val_labtr = [IgnoreLabelClass(), ToTensorLabel()]
    val_cotr = [] if args.val_orig else [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir,
                       args.dataset_dir,
                       img_transform=Compose(val_imgtr),
                       label_transform=Compose(val_labtr),
                       co_transform=Compose(val_cotr),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = deeplabv2.ResDeeplab()
    init_weights(generator, args.init_net)

    # Only parameters that require gradients are handed to the optimizer.
    g_params = (p for p in generator.parameters() if p.requires_grad)
    optimG = optim.SGD(g_params,
                       lr=args.g_lr,
                       momentum=0.9,
                       weight_decay=0.0001,
                       nesterov=True)

    if use_gpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    needs_discriminator = args.mode != "base"
    if needs_discriminator:
        discriminator = Dis(in_channels=21)
        d_params = (p for p in discriminator.parameters() if p.requires_grad)
        if args.d_optim == 'adam':
            optimD = optim.Adam(d_params, lr=args.d_lr, weight_decay=0.0001)
        else:
            optimD = optim.SGD(d_params,
                               lr=args.d_lr,
                               weight_decay=0.0001,
                               momentum=0.5,
                               nesterov=True)

        if use_gpu:
            discriminator = nn.DataParallel(discriminator).cuda()

    ############
    # DISPATCH #
    ############
    if not needs_discriminator:
        train_base(generator, optimG, trainloader_l, valoader, args)
    elif args.mode == 'adv':
        train_adv(generator, discriminator, optimG, optimD, trainloader_l,
                  valoader, args)
    else:
        print("Semi-Supervised training")
        train_semi(generator, discriminator, optimG, optimD, trainloader_l,
                   trainloader_u, valoader, args)
def main():
    """Log the run configuration, build the data loaders and ResUNet, then train.

    All settings (``home_dir``, ``mode``, ``g_lr``, ``batch_size``, ``f_path``,
    ``use_cuda``, ``val_orig``, ``no_norm``, ...) are read from the enclosing
    module scope rather than from parsed arguments. The configuration is
    appended as JSON to ``f_path`` before anything else is built.

    Only ``mode == 'base'`` is implemented; any other mode prints a message
    and returns without training.
    """
    random.seed(0)
    torch.manual_seed(0)

    # Named ``run_config`` (not ``dict``) so the builtin is not shadowed.
    run_config = {
        'home_dir': home_dir,
        'pretrain': str(pretrain),
        'use_cuda': use_cuda,
        'mode': mode,
        'g_lr': g_lr,
        'd_lr': d_lr,
        'lam_semi': lam_semi,
        'lam_adv': lam_adv,
        't_semi': t_semi,
        'batch_size': batch_size,
        'wait_semi': wait_semi,
        'd_optim': d_optim,
        'snapshot_dir': snapshot_dir,
        'd_label_smooth': d_label_smooth,
        'val_orig': val_orig,
        'start_epoch': start_epoch,
        'max_epoch': max_epoch
    }

    # Serialise before opening so a serialisation error never touches the log.
    json_str = json.dumps(run_config, indent=4)
    with open(f_path, 'a') as f:
        f.write(json_str)
        f.write('\n')

    # Alternative statistics the author tagged "PVP": mean=[0.459], std=[0.250].
    normalize = Normalize(mean=[0.414], std=[0.227])  # tagged "AP" by the author
    if use_cuda:
        torch.cuda.manual_seed_all(0)

    #######################
    # Training Dataloader #
    #######################
    crop_size = 112
    imgtr = [
        CenterCrop((crop_size, crop_size)),
        Resize((112, 112)),
        ToTensor(), normalize
    ]
    labtr = [
        CenterCrop((crop_size, crop_size)),
        Resize((112, 112)),
        ToTensorLabel()
    ]
    cotr = []

    trainset_l = Promise_data(home_dir,
                              img_transform=Compose(imgtr),
                              label_transform=Compose(labtr),
                              co_transform=Compose(cotr),
                              labelled=True,
                              valid=False)
    trainloader_l = DataLoader(trainset_l,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=0)

    #########################
    # Validation Dataloader #
    #########################
    # The original code had two byte-identical branches on ``val_orig`` (use
    # the original images or not); they are merged here because the flag had
    # no effect on the transforms that were built.
    # NOTE(review): ``no_norm`` selects the fixed ``normalize`` above and its
    # absence selects ``NormalizeOwn()`` - the flag name suggests the opposite.
    # Behaviour is kept as-is; confirm the intent.
    imgtr = [
        CenterCrop((crop_size, crop_size)),
        Resize((112, 112)),
        ToTensor(),
        normalize if no_norm else NormalizeOwn()
    ]
    labtr = [
        CenterCrop((crop_size, crop_size)),
        Resize((112, 112)),
        ToTensorLabel()
    ]
    cotr = []

    valset = Promise_data(home_dir,
                          img_transform=Compose(imgtr),
                          label_transform=Compose(labtr),
                          co_transform=Compose(cotr),
                          labelled=True,
                          valid=True)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    # Single-channel input, two output classes. Loading pretrained weights
    # (the ``pretrain`` setting) is currently not wired in.
    generator = ResUNet(in_ch=1, out_ch=2, base_ch=64)

    optimG = optim.Adam(filter(lambda p: p.requires_grad,
                               generator.parameters()),
                        lr=g_lr,
                        weight_decay=0.0001)
    if use_cuda:
        generator = generator.cuda()

    if mode == 'base':
        train_base(generator, optimG, trainloader_l, valoader, f_path)
    else:
        # Previously an unsupported mode ended silently without training.
        print("training mode incorrect")
Example #9
0
def main():
    """Build the Corrosion pipelines, an attention U-Net generator and a DenseCRF
    post-processor, then run the training mode named by ``args.mode``.

    ``'base'`` trains the generator alone, ``'adv'`` trains adversarially (and
    is the only mode that receives the CRF post-processor), ``'semi'`` adds an
    unlabelled loader. Any other mode only prints an error message.
    """
    args = parse_args()

    # DenseCRF hyper-parameters come from a YAML file resolved against the
    # current working directory.
    crf_cfg_path = osp.join(os.getcwd(), "utils/config_crf.yaml")
    with open(crf_cfg_path) as cfg_file:
        crf_config = Dict(yaml.safe_load(cfg_file))

    use_gpu = not args.nogpu
    random.seed(0)
    torch.manual_seed(0)
    if use_gpu:
        torch.cuda.manual_seed_all(0)

    ########################
    # Training Dataloaders #
    ########################
    train_imgtr = [ToTensor()] if args.no_norm else [ToTensor(), NormalizeOwn()]
    # Default label tensor type (softmax setup).
    train_labtr = [IgnoreLabelClass(), ToTensorLabel()]
    train_cotr = [RandomSizedCrop3((320, 320))]

    print("dataset_dir: ", args.dataset_dir)

    # Semi-supervised runs use an 0.8 split; other modes use the full set.
    split_ratio = 0.8 if args.mode == 'semi' else 1.0

    def make_trainloader(labeled):
        dataset = Corrosion(home_dir,
                            args.dataset_dir,
                            img_transform=Compose(train_imgtr),
                            label_transform=Compose(train_labtr),
                            co_transform=Compose(train_cotr),
                            split=split_ratio,
                            labeled=labeled)
        return DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=2,
                          drop_last=True)

    trainloader_l = make_trainloader(True)
    if args.mode == 'semi':
        trainloader_u = make_trainloader(False)

    crf = crf_config.CRF
    postprocessor = DenseCRF(
        iter_max=crf.ITER_MAX,
        pos_xy_std=crf.POS_XY_STD,
        pos_w=crf.POS_W,
        bi_xy_std=crf.BI_XY_STD,
        bi_rgb_std=crf.BI_RGB_STD,
        bi_w=crf.BI_W,
    )

    #########################
    # Validation Dataloader #
    #########################
    # ``val_orig`` zero-pads the image and skips the joint crop; otherwise the
    # validation set is cropped to 320x320 like the training set.
    val_imgtr = [ZeroPadding(), ToTensor()] if args.val_orig else [ToTensor()]
    if not args.no_norm:
        val_imgtr.append(NormalizeOwn())
    val_labtr = [IgnoreLabelClass(), ToTensorLabel()]
    val_cotr = [] if args.val_orig else [RandomSizedCrop3((320, 320))]

    valset = Corrosion(home_dir,
                       args.dataset_dir,
                       img_transform=Compose(val_imgtr),
                       label_transform=Compose(val_labtr),
                       co_transform=Compose(val_cotr),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = unet.AttU_Net()
    init_weights(generator, args.init_net)

    g_params = (p for p in generator.parameters() if p.requires_grad)
    if args.init_net == 'unet':
        optimG = optim.Adam(g_params, args.g_lr, [0.9, 0.999])
    else:
        optimG = optim.SGD(g_params,
                           lr=args.g_lr,
                           momentum=0.9,
                           weight_decay=0.0001,
                           nesterov=True)

    if use_gpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    if args.mode != "base":
        discriminator = DisSigmoid(in_channels=2)
        # NOTE(review): this re-initialises the *generator* (already wrapped
        # and moved to the GPU above), not the freshly built discriminator.
        # It looks like a copy-paste slip; kept unchanged - confirm intent.
        init_weights(generator, args.init_net)

        d_params = (p for p in discriminator.parameters() if p.requires_grad)
        if args.d_optim == 'adam':
            optimD = optim.Adam(d_params, args.d_lr, [0.9, 0.999])
        else:
            optimD = optim.SGD(d_params,
                               lr=args.d_lr,
                               weight_decay=0.0001,
                               momentum=0.9,
                               nesterov=True)

        if use_gpu:
            discriminator = nn.DataParallel(discriminator).cuda()

    ############
    # DISPATCH #
    ############
    if args.mode == 'base':
        train_base(generator, optimG, trainloader_l, valoader, args)
    elif args.mode == 'adv':
        train_adv(generator, discriminator, optimG, optimD, trainloader_l,
                  valoader, postprocessor, args)
    elif args.mode == 'semi':
        train_semi(generator, discriminator, optimG, optimD, trainloader_l,
                   trainloader_u, valoader, args)
    else:
        print("training mode incorrect")
Example #10
0
def main():
    """Build the BoxSet pipelines and a 7-class attention U-Net, then train it.

    Configuration comes from ``parse_args()``. If ``args.snapshot`` names an
    existing file, matching generator weights are restored from it; otherwise
    the generator is initialised with ``init_weights``. ``args.mode`` selects
    ``'base'`` or ``'label_correction'`` training; anything else only prints
    an error message.
    """
    args = parse_args()

    random.seed(0)
    torch.manual_seed(0)
    if not args.nogpu:
        torch.cuda.manual_seed_all(0)

    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    # ``args.lr_step`` is a comma-separated list of integers. ``steps`` is
    # always bound (empty when no steps are given) because it is passed to
    # ``train_box_cluster`` below; previously an empty value left it undefined.
    steps = [int(s) for s in args.lr_step.split(',')] if args.lr_step else []

    # Default label tensor type (softmax setup).
    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop4((512, 512))]

    print("dataset_dir: ", args.dataset_dir)

    trainset_l = BoxSet(home_dir,
                        args.dataset_dir,
                        img_transform=Compose(imgtr),
                        label_transform=Compose(labtr),
                        co_transform=Compose(cotr),
                        split=args.split,
                        labeled=True,
                        label_correction=True)
    trainloader_l = DataLoader(trainset_l,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=2,
                               drop_last=True)
    if args.split != 1:
        trainset_u = BoxSet(home_dir,
                            args.dataset_dir,
                            img_transform=Compose(imgtr),
                            label_transform=Compose(labtr),
                            co_transform=Compose(cotr),
                            split=args.split,
                            labeled=False,
                            label_correction=True)
        # Fixed: the unlabelled loader used to wrap ``trainset_l``.
        # NOTE(review): neither training routine below receives this loader.
        trainloader_u = DataLoader(trainset_u,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=2,
                                   drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        # Zero-pad the image and apply no joint transform.
        imgtr = [ZeroPadding(), ToTensor()]
        cotr_val = None
    else:
        imgtr = [ToTensor()]
        cotr_val = (512, 512)
    if not args.no_norm:
        imgtr.append(NormalizeOwn())
    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [] if cotr_val is None else [ResizedImage4(cotr_val)]

    valset = BoxSet(home_dir,
                    args.dataset_dir,
                    img_transform=Compose(imgtr),
                    label_transform=Compose(labtr),
                    co_transform=Compose(cotr),
                    train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = unet.AttU_Net(output_ch=7, Centroids=False)

    if osp.isfile(args.snapshot):
        print("load checkpoint => ", args.snapshot)
        checkpoint = torch.load(args.snapshot)
        generator_dict = generator.state_dict()
        # Keep only entries whose name, once everything up to and including
        # the first 'module.' is stripped, exists in this generator. Keys
        # without that prefix map to '' and are therefore dropped.
        saved_net = {
            k.partition('module.')[2]: v
            for k, v in checkpoint['state_dict'].items()
            if k.partition('module.')[2] in generator_dict
        }
        generator_dict.update(saved_net)
        # Fixed: load the merged dict. Loading ``saved_net`` alone discarded
        # the merge and fails strict loading when the checkpoint is partial.
        generator.load_state_dict(generator_dict)
    else:
        init_weights(generator, args.init_net)

    # The original branched on ``args.init_net != 'unet'`` but built the
    # identical Adam optimizer in both branches, so the branch is removed.
    optimG = optim.Adam(filter(lambda p: p.requires_grad,
                               generator.parameters()),
                        args.g_lr, [0.5, 0.999])

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    if args.mode == 'base':
        train_base(generator, optimG, trainloader_l, valoader, args)
    elif args.mode == 'label_correction':
        train_box_cluster(generator, steps, optimG, trainloader_l, valoader,
                          args)
    else:
        print("training mode incorrect")
def train_semi(args):
    """Train a DeepLab generator and a discriminator with a semi-supervised loss.

    Every loaded batch is split in half; one half (chosen at random) is used
    as labelled data, the other as unlabelled data. Per iteration:

    1. the discriminator is trained on one-hot ground truth (real, target 1)
       and on detached generator predictions (fake, target 0);
    2. the generator is trained on the labelled half with cross-entropy plus
       ``args.lam_adv`` times the adversarial loss;
    3. once ``epoch > args.wait_semi``, the generator is additionally trained
       on the unlabelled half: pixels the discriminator rates as real with
       probability above ``args.t_semi`` use the generator's own hard
       prediction as target, weighted by ``args.lam_semi``.

    Side effect: ``args.batch_size`` is doubled in place.

    Uses the legacy PyTorch API (``Variable``, ``volatile``, ``.data[0]``,
    ``NLLLoss2d``) like the rest of this file.
    """
    # TODO: Make it more generic to include for other splits
    # Doubled because each batch is halved into labelled/unlabelled parts.
    args.batch_size = args.batch_size * 2

    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop((321, 321))]

    trainset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), label_transform=Compose(labtr), \
        co_transform=Compose(cotr))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    #########################
    # Validation Dataloader #
    ########################
    if args.val_orig:
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), \
        label_transform = Compose(labtr),co_transform=Compose(cotr),train_phase=False)
    valoader = DataLoader(valset, batch_size=1)
    #############
    # GENERATOR #
    #############
    generator = deeplabv2.ResDeeplab()
    optimG = optim.SGD(filter(lambda p: p.requires_grad, \
        generator.parameters()),lr=args.g_lr,momentum=0.9,\
        weight_decay=0.0001,nesterov=True)

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    ################
    discriminator = Dis(in_channels=21)
    if args.d_optim == 'adam':
        optimD = optim.Adam(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr = args.d_lr)
    else:
        optimD = optim.SGD(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr=args.d_lr,weight_decay=0.0001,momentum=0.5,nesterov=True)

    if not args.nogpu:
        discriminator = nn.DataParallel(discriminator).cuda()

    ############
    # TRAINING #
    ############
    # NOTE(review): ``best_miou`` is never updated in this function; if
    # ``snapshot`` returns the new best value, it is discarded below.
    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        for batch_id, (img, mask, ohmask) in enumerate(trainloader):
            if args.nogpu:
                img, mask, ohmask = Variable(img), Variable(mask), Variable(
                    ohmask)
            else:
                img, mask, ohmask = Variable(img.cuda()), Variable(
                    mask.cuda()), Variable(ohmask.cuda())
            # Global iteration counter fed to the LR schedulers.
            itr = len(trainloader) * (epoch - 1) + batch_id
            ## TODO: Extend random interleaving for split of any size
            mid = args.batch_size // 2

            # Random Interleaving: pick which half plays the labelled role.
            if random.random() < 0.5:
                lab, unl = slice(0, mid), slice(mid, None)
            else:
                lab, unl = slice(mid, None), slice(0, mid)
            imgl, maskl, ohmaskl = img[lab], mask[lab], ohmask[lab]
            # Only the images of the unlabelled half are used.
            imgu = img[unl]

            ################################################
            #  Labelled data for Discriminator Training #
            ################################################
            # ``volatile=True``: forward pass only, no graph is needed here.
            cpmap = generator(Variable(imgl.data, volatile=True))
            cpmap = nn.Softmax2d()(cpmap)

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]

            # Generate the Real and Fake Labels (1 = real, 0 = fake)
            targetf = Variable(torch.zeros((N, H, W)).long())
            targetr = Variable(torch.ones((N, H, W)).long())
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            # Train on Real
            confr = nn.LogSoftmax()(discriminator(ohmaskl.float()))
            optimD.zero_grad()
            if args.d_label_smooth != 0:
                # Label smoothing: part of the real loss is taken against
                # the fake target.
                LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(confr,
                                                                 targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr, targetf)
            else:
                LDr = nn.NLLLoss2d()(confr, targetr)
            LDr.backward()

            # Train on Fake (re-wrapping ``.data`` detaches the generator).
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff, targetf)
            LDf.backward()

            poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            ###########################################
            #  labelled data Generator Training       #
            ###########################################
            optimG.zero_grad()

            cpmap = generator(imgl)
            cpmapsmax = nn.Softmax2d()(cpmap)
            cpmaplsmax = nn.LogSoftmax()(cpmap)

            conff = nn.LogSoftmax()(discriminator(cpmapsmax))

            LGce = nn.NLLLoss2d()(cpmaplsmax, maskl)
            # The generator wants its output to be classified as real.
            LGadv = nn.NLLLoss2d()(conff, targetr)

            # Unweighted scalar copies for logging.
            LGadv_d = LGadv.data[0]
            LGce_d = LGce.data[0]

            LGadv = args.lam_adv * LGadv

            (LGce + LGadv).backward()
            #####################################
            # Use unlabelled data to get L_semi #
            #####################################
            LGsemi_d = 0
            if epoch > args.wait_semi:

                cpmap = generator(imgu)
                softpred = nn.Softmax2d()(cpmap)
                hardpred = torch.max(softpred, 1)[1].squeeze(1)
                # Discriminator confidence; channel 1 is the "real" class.
                conf = nn.Softmax2d()(discriminator(
                    Variable(softpred.data, volatile=True)))

                # One-hot selection mask, built channels-last so an (N, H, W)
                # boolean array can pick whole pixels.
                n_img, n_classes, height, width = cpmap.size()
                idx = np.zeros((n_img, height, width, n_classes),
                               dtype=np.uint8)

                # Fixed: threshold the discriminator confidence. The code
                # used the generator's class-1 logits (``cpmap[:, 1]``) and
                # never read ``conf``.
                confnp = conf[:, 1, ...].data.cpu().numpy()
                hardprednp = hardpred.data.cpu().numpy()
                trusted = confnp > args.t_semi
                idx[trusted] = np.identity(
                    n_classes, dtype=idx.dtype)[hardprednp[trusted]]

                if np.count_nonzero(idx) != 0:
                    cpmaplsmax = nn.LogSoftmax()(cpmap)
                    # Fixed: bring the mask back to (N, C, H, W) so it lines
                    # up with ``cpmaplsmax``, and honour ``args.nogpu``
                    # instead of always calling ``.cuda()``.
                    idx = torch.from_numpy(
                        np.ascontiguousarray(idx.transpose(0, 3, 1,
                                                           2))).byte()
                    if not args.nogpu:
                        idx = idx.cuda()
                    idx = Variable(idx)
                    # Log-probability of the predicted class at trusted
                    # pixels; its negative mean is the self-training loss.
                    LGsemi_arr = cpmaplsmax.masked_select(idx)
                    LGsemi = -1 * LGsemi_arr.mean()
                    LGsemi_d = LGsemi.data[0]
                    LGsemi = args.lam_semi * LGsemi
                    LGsemi.backward()

                # Manually free memory! Later, really understand how
                # computation graphs free variables
                del idx
                del conf
                del confnp
                del trusted
                del hardpred
                del softpred
                del hardprednp
                del cpmap
            LGseg_d = LGce_d + LGadv_d + LGsemi_d
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()

            # Fixed: LDf (fake) and LDr (real) were printed under each
            # other's label.
            print("[{}][{}] LD: {:.4f} LD_fake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f} LG_semi: {:.4f}"\
                    .format(epoch,itr,(LDr + LDf).data[0],LDf.data[0],LDr.data[0],LGseg_d,LGce_d,LGadv_d,LGsemi_d))

        snapshot(generator, valoader, epoch, best_miou, args.snapshot_dir,
                 args.prefix)
def train_adv(args):
    """Adversarially train a DeepLab generator against a discriminator on PascalVOC.

    Per iteration the discriminator is first trained on one-hot ground truth
    (real, target 1) and on detached generator predictions (fake, target 0);
    then the generator is trained with cross-entropy plus ``args.lam_adv``
    times the adversarial loss. ``snapshot`` is called after every epoch.

    Uses the legacy PyTorch API (``Variable``, ``volatile``, ``.data[0]``,
    ``NLLLoss2d``) like the rest of this file.
    """
    #######################
    # Training Dataloader #
    #######################
    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop((321, 321))]

    trainset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), label_transform=Compose(labtr), \
        co_transform=Compose(cotr))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    #########################
    # Validation Dataloader #
    ########################
    # ``val_orig`` zero-pads the image and applies no joint crop; otherwise
    # validation images get the same random 321x321 crop as training.
    if args.val_orig:
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), \
        label_transform = Compose(labtr),co_transform=Compose(cotr),train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = deeplabv2.ResDeeplab()
    # Only parameters that require gradients are optimised.
    optimG = optim.SGD(filter(lambda p: p.requires_grad, \
        generator.parameters()),lr=args.g_lr,momentum=0.9,\
        weight_decay=0.0001,nesterov=True)

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    ################
    discriminator = Dis(in_channels=21)
    if args.d_optim == 'adam':
        optimD = optim.Adam(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr = args.d_lr)
    else:
        optimD = optim.SGD(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr=args.d_lr,weight_decay=0.0001,momentum=0.5,nesterov=True)

    if not args.nogpu:
        discriminator = nn.DataParallel(discriminator).cuda()

    #############
    # TRAINING  #
    #############
    # NOTE(review): ``best_miou`` is never updated in this function; if
    # ``snapshot`` returns the new best value, it is discarded below.
    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        for batch_id, (img, mask, ohmask) in enumerate(trainloader):
            if args.nogpu:
                img, mask, ohmask = Variable(img), Variable(mask), Variable(
                    ohmask)
            else:
                img, mask, ohmask = Variable(img.cuda()), Variable(
                    mask.cuda()), Variable(ohmask.cuda())
            # Global iteration counter fed to the LR schedulers.
            itr = len(trainloader) * (epoch - 1) + batch_id
            # ``volatile=True``: forward pass only; this prediction is just
            # the "fake" input for the discriminator step.
            cpmap = generator(Variable(img.data, volatile=True))
            cpmap = nn.Softmax2d()(cpmap)

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]

            # Generate the Real and Fake Labels (1 = real, 0 = fake)
            targetf = Variable(torch.zeros((N, H, W)).long(),
                               requires_grad=False)
            targetr = Variable(torch.ones((N, H, W)).long(),
                               requires_grad=False)
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            ##########################
            # DISCRIMINATOR TRAINING #
            ##########################
            optimD.zero_grad()

            # Train on Real
            confr = nn.LogSoftmax()(discriminator(ohmask.float()))
            if args.d_label_smooth != 0:
                # Label smoothing: part of the real loss is taken against
                # the fake target.
                LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(confr,
                                                                 targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr, targetf)
            else:
                LDr = nn.NLLLoss2d()(confr, targetr)
            LDr.backward()

            # Train on Fake (re-wrapping ``.data`` detaches the generator).
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff, targetf)
            LDf.backward()

            poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            ######################
            # GENERATOR TRAINING #
            #####################
            optimG.zero_grad()

            cmap = generator(img)
            cpmapsmax = nn.Softmax2d()(cmap)
            cpmaplsmax = nn.LogSoftmax()(cmap)
            conff = nn.LogSoftmax()(discriminator(cpmapsmax))

            LGce = nn.NLLLoss2d()(cpmaplsmax, mask)
            # The generator wants its output to be classified as real.
            LGadv = nn.NLLLoss2d()(conff, targetr)
            LGseg = LGce + args.lam_adv * LGadv

            LGseg.backward()
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()

            # Fixed: LDf (fake) and LDr (real) were printed under each
            # other's label.
            print("[{}][{}] LD: {:.4f} LDfake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f}"  \
                    .format(epoch,itr,(LDr + LDf).data[0],LDf.data[0],LDr.data[0],LGseg.data[0],LGce.data[0],LGadv.data[0]))
        snapshot(generator, valoader, epoch, best_miou, args.snapshot_dir,
                 args.prefix)
def train_base(args):
    """Train the segmentation network alone with a per-pixel NLL loss.

    Builds the training and validation ``PascalVOC`` loaders, creates a
    ``deeplabv2.ResDeeplab`` model initialised through ``init_weights``,
    and optimises it with Nesterov SGD. After every epoch ``snapshot`` is
    called with the validation loader.

    Args:
        args: parsed options. This function reads ``no_norm``,
            ``dataset_dir``, ``batch_size``, ``val_orig``, ``init_net``,
            ``g_lr``, ``nogpu``, ``start_epoch``, ``max_epoch``,
            ``snapshot_dir`` and ``prefix``.
    """
    crop_size = (321, 321)

    #######################
    # Training Dataloader #
    #######################
    train_img_ops = [ToTensor()]
    if not args.no_norm:
        train_img_ops.append(NormalizeOwn())
    train_label_ops = [IgnoreLabelClass(), ToTensorLabel()]
    train_joint_ops = [RandomSizedCrop(crop_size)]

    trainset = PascalVOC(home_dir,
                         args.dataset_dir,
                         img_transform=Compose(train_img_ops),
                         label_transform=Compose(train_label_ops),
                         co_transform=Compose(train_joint_ops))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    # With --val_orig the images are zero-padded and no joint crop is
    # applied; otherwise validation uses the same random crop as training.
    val_img_ops = [ZeroPadding()] if args.val_orig else []
    val_img_ops.append(ToTensor())
    if not args.no_norm:
        val_img_ops.append(NormalizeOwn())
    val_label_ops = [IgnoreLabelClass(), ToTensorLabel()]
    val_joint_ops = [] if args.val_orig else [RandomSizedCrop(crop_size)]

    valset = PascalVOC(home_dir,
                       args.dataset_dir,
                       img_transform=Compose(val_img_ops),
                       label_transform=Compose(val_label_ops),
                       co_transform=Compose(val_joint_ops),
                       train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #########################
    # Model and optimiser   #
    #########################
    model = deeplabv2.ResDeeplab()
    init_weights(model, args.init_net)

    # Only parameters that still require gradients are handed to SGD.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimG = optim.SGD(trainable_params,
                       lr=args.g_lr,
                       momentum=0.9,
                       weight_decay=0.0001,
                       nesterov=True)

    if not args.nogpu:
        model = nn.DataParallel(model).cuda()

    # The loss modules are built once and reused for every batch.
    log_softmax = nn.LogSoftmax()
    nll_loss = nn.NLLLoss2d()

    # NOTE(review): best_miou is never reassigned in this function, so each
    # snapshot() call below receives -1. Confirm whether snapshot's return
    # value was meant to be captured here.
    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        model.train()
        # Global iteration index of the first batch of this epoch
        # (presumably epochs are numbered from 1 -- TODO confirm).
        epoch_offset = len(trainloader) * (epoch - 1)

        for batch_id, (img, mask, _) in enumerate(trainloader):
            if not args.nogpu:
                img, mask = img.cuda(), mask.cuda()
            img, mask = Variable(img), Variable(mask)

            itr = epoch_offset + batch_id
            log_probs = log_softmax(model(img))
            Lseg = nll_loss(log_probs, mask)

            # Update the learning rate for this iteration before stepping.
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.zero_grad()
            Lseg.backward()
            optimG.step()

            print("[{}][{}]Loss: {:0.4f}".format(epoch, itr, Lseg.data[0]))

        snapshot(model, valoader, epoch, best_miou, args.snapshot_dir,
                 args.prefix)