def train_base(generator, optimG, trainloader, valoader, args):

    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        for batch_id, (img, mask, _) in enumerate(trainloader):

            if args.nogpu:
                img, mask = Variable(img), Variable(mask)
            else:
                img, mask = Variable(img.cuda()), Variable(mask.cuda())

            itr = len(trainloader) * (epoch - 1) + batch_id
            cprob = generator(img)
            cprob = nn.LogSoftmax()(cprob)

            Lseg = nn.NLLLoss2d()(cprob, mask)

            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.zero_grad()

            Lseg.backward()
            optimG.step()

            print("[{}][{}]Loss: {:0.4f}".format(epoch, itr, Lseg.data[0]))

        best_miou = snapshot(generator, valoader, epoch, best_miou,
                             args.snapshot_dir, args.prefix)
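
# NOTE: poly_lr_scheduler and snapshot are repo helpers that are not shown in these
# examples. A minimal sketch of the polynomial LR decay that poly_lr_scheduler
# appears to implement is below; max_iter and power are illustrative guesses, not
# the repo's actual values.
def poly_lr_scheduler(optimizer, init_lr, itr, max_iter=30000, power=0.9):
    """Set every param group's lr to init_lr * (1 - itr / max_iter) ** power."""
    lr = init_lr * (1.0 - float(itr) / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer
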
def train_base(generator, optimG, trainloader, valoader, f_path):
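    # NOTE: start_epoch, max_epoch, use_cuda, g_lr, snapshot_dir and prefix are not
    # parameters here; they are assumed to be module-level configuration globals.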
    best_miou_unlabel = -1
    best_miou = -1
    for epoch in range(start_epoch, max_epoch + 1):
        generator.train()
        for batch_id, (img, mask, _, _) in enumerate(trainloader):

            if not use_cuda:
                img, mask = img, mask
            else:
                img, mask = img.cuda(), mask.cuda()

            # print(mask.max())  # check the maximum value of the mask

            itr = len(trainloader) * (epoch - 1) + batch_id
            cprob = generator(img)
            cprob = nn.LogSoftmax()(cprob)
            # print(cprob.shape)
            # print(mask.shape)

            Lseg = nn.NLLLoss()(cprob, mask)

            optimG = poly_lr_scheduler(optimG, g_lr, itr)
            optimG.zero_grad()

            Lseg.backward()
            optimG.step()

            print("[{}][{}]Loss: {:0.4f}".format(epoch, itr, Lseg.item()))

        best_miou_unlabel, best_miou = snapshot(
            generator, optimG, trainloader, valoader, epoch, best_miou_unlabel,
            best_miou, snapshot_dir, prefix + '_' + str(epoch), f_path)
Example No. 3
def train_semi(generator,discriminator,optimG,optimD,trainloader_l,trainloader_u,valoader,args):

    best_miou = -1
    best_eiou = -1
    for epoch in range(args.start_epoch,args.max_epoch+1):
        generator.train()
        trainloader_l_iter = iter(trainloader_l)
        trainloader_u_iter = iter(trainloader_u)
        print("Epoch: {}".format(epoch))
        batch_id = 0
        while True:
            batch_id += 1
            itr = (len(trainloader_u) + len(trainloader_l))*(epoch-1) + batch_id
            LGsemi_d = 0
            LGsemi_c = 0
            if epoch > args.wait_semi:
                if random.random() <0.5:
                    loader_l = trainloader_l_iter
                    loader_u = trainloader_u_iter
                else:
                    loader_l = trainloader_u_iter
                    loader_u = trainloader_l_iter
                # Check if the loader has a batch available
                try:
                    img_u,mask_u,ohmask_u,_ = next(loader_u)
                except StopIteration:
                    trainloader_u_iter = iter(trainloader_u)
                    loader_u = trainloader_u_iter
                    img_u,mask_u,ohmask_u,_ = next(loader_u)

                if args.nogpu:
                    img_u,mask_u,ohmask_u = Variable(img_u),Variable(mask_u),Variable(ohmask_u)
                else:
                    img_u,mask_u,ohmask_u = Variable(img_u.cuda()),Variable(mask_u.cuda()),Variable(ohmask_u.cuda())

                # semi unlabelled training
                cpmap = generator(img_u)
                cpmapsmax = nn.Softmax2d()(cpmap)
                conf = discriminator(cpmapsmax)
                confsmax = nn.Softmax2d()(conf)
                conflsmax = nn.LogSoftmax()(conf)

                N = cpmap.size()[0]
                H = cpmap.size()[2]
                W = cpmap.size()[3]
                # Adversarial Loss
                targetr = Variable(torch.ones((N,H,W)).long())
                if not args.nogpu:
                    targetr = targetr.cuda()
                LGadv = nn.NLLLoss2d()(conflsmax,targetr)
                LGadv_d = LGadv.data

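                # Self-training step: take the generator's per-pixel argmax as a hard
                # pseudo-label, keeping only pixels where the discriminator's "real"
                # confidence (channel 1 of confsmax) exceeds t_semi.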
                hardpred = torch.max(cpmapsmax,1)[1].squeeze(1)
                idx = np.zeros(cpmap.data.cpu().numpy().shape,dtype=np.uint8)
                idx = idx.transpose(0, 2, 3, 1)
                confnp = confsmax[:,1,...].data.cpu().numpy()
                hardprednp = hardpred.data.cpu().numpy()
                idx[confnp > args.t_semi] = np.identity(2, dtype=idx.dtype)[hardprednp[confnp > args.t_semi]]
                LG = args.lam_adv*LGadv
                if np.count_nonzero(idx) != 0:
                    cpmaplsmax = nn.LogSoftmax()(cpmap)
                    idx = Variable(torch.from_numpy(idx.transpose(0,3,1,2)).byte().cuda())
                    LGsemi_arr = cpmaplsmax.masked_select(idx.bool())
                    LGsemi = -1*LGsemi_arr.mean()
                    LGsemi_d = LGsemi.data
                    # LGsemi_c = LGsemi.cpu().numpy()
                    LGsemi = args.lam_semi*LGsemi
                    LG += LGsemi
                LG.backward()
                optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
                optimG.step()
            ################################################
            #  train labelled data                         #
            ################################################
            loader_l = trainloader_l_iter
            loader_u = trainloader_u_iter
            try:
                if random.random() <0.5:
                    img_l,mask_l,ohmask_l,_ = next(loader_l)
                else:
                    img_l,mask_l,ohmask_l,_ = next(loader_u)
            except StopIteration:
                break
            if args.nogpu:
                img_l,mask_l,ohmask_l = Variable(img_l),Variable(mask_l),Variable(ohmask_l)
            else:
                img_l,mask_l,ohmask_l = Variable(img_l.cuda()),Variable(mask_l.cuda()),Variable(ohmask_l.cuda())
            ################################################
            #  Labelled data for Discriminator Training #
            ################################################
            cpmap = generator(Variable(img_l.data,volatile=True))
            cpmap = nn.Softmax2d()(cpmap)

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]
            # Generate the Real and Fake Labels
            targetf = Variable(torch.zeros((N,H,W)).long())
            targetr = Variable(torch.ones((N,H,W)).long())
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

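            # One-sided label smoothing: a d_label_smooth fraction of the real-branch
            # loss is pushed toward the fake target so the discriminator does not saturate.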
            # Train on Real
            confr = nn.LogSoftmax()(discriminator(ohmask_l.float()))
            optimD.zero_grad()
            if args.d_label_smooth != 0:
                LDr = (1 - args.d_label_smooth)*nn.NLLLoss2d()(confr,targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr,targetf)
            else:
                LDr = nn.NLLLoss2d()(confr,targetr)
            LDr.backward()

            # Train on Fake
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff,targetf)
            LDf.backward()
            LDr_d = LDr.data
            # LDr_c = LDr.cpu().numpy()
            LDf_d = LDf.data
            # LDf_c = LDf.cpu().numpy()
            LD_d = LDr_d + LDf_d
            # LD_c = LDr_c + LDf_c
            optimD = poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            #####################################
            #  labelled data Generator Training #
            #####################################
            optimG.zero_grad()
            optimD.zero_grad()
            cpmap = generator(img_l)
            # print("cpmap: ", cpmap.size())
            cpmapsmax = nn.Softmax2d()(cpmap)
            # print("cpmapsmax: ", cpmapsmax.size())
            cpmaplsmax = nn.LogSoftmax()(cpmap)
            # print("cpmaplsmax: ", cpmaplsmax.size())
            conff = nn.LogSoftmax()(discriminator(cpmapsmax))

            LGce = nn.NLLLoss2d()(cpmaplsmax,mask_l)
            LGadv = nn.NLLLoss2d()(conff,targetr)
            LGadv_d = LGadv.data
            # LGadv_c = LGadv.cpu().numpy()
            LGce_d = LGce.data
            # LGce_c = LGce.cpu().numpy()

            # LGsemi_d = 0 # No semi-supervised training
            LGadv = args.lam_adv*LGadv
            (LGce + LGadv).backward()
            optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()
            LGseg_d = LGce_d + LGadv_d + LGsemi_d
            # LGseg_c = LGce_c + LGadv_c + LGsemi_c

            # training_log = [epoch,itr,LD_c,LDr_c,LDf_c,LGseg_c,LGce_c,LGadv_c,LGsemi_c]
            print("[{}][{}] LD: {:.4f} LD_fake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f} LG_semi: {:.4f}"\
                    .format(epoch,itr,LD_d,LDr_d,LDf_d,LGseg_d,LGce_d,LGadv_d,LGsemi_d))

        # best_miou = snapshot(generator,valoader,epoch,best_miou,args.snapshot_dir,args.prefix)
        best_miou, best_eiou = snapshot_segdis(generator,discriminator,valoader,epoch,best_miou,best_eiou,args.snapshot_dir,args.prefix)
Example No. 4
def train_adv(generator,discriminator,optimG,optimD,trainloader,valoader,args,ws):
    best_eiou = -1
    best_miou = -1
    for epoch in range(args.start_epoch,args.max_epoch+1):
        generator.train()
        for batch_id, (img,mask,ohmask,_) in enumerate(trainloader):
            if args.nogpu:
                img,mask,ohmask = Variable(img),Variable(mask),Variable(ohmask)
            else:
                img,mask,ohmask = Variable(img.cuda()),Variable(mask.cuda()),Variable(ohmask.cuda())
            itr = len(trainloader)*(epoch-1) + batch_id
            # generator forward
            cpmap = generator(Variable(img.data,volatile=True))
            cpmap = nn.Softmax2d()(cpmap)
            # print("cpmap: ", cpmap.size(), " ohmask: ", ohmask.size())

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]
            # print("cpmap: ", cpmap.size())

            # Generate the Real and Fake Labels
            targetf = Variable(torch.zeros((N,H,W)).long(),requires_grad=False)
            targetr = Variable(torch.ones((N,H,W)).long(),requires_grad=False)
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            ##########################
            # DISCRIMINATOR TRAINING #
            ##########################
            optimD.zero_grad()

            # Train on Real
            confr = nn.LogSoftmax()(discriminator(ohmask.float()))
            # print("confr: ", confr.size())
            # print("targetr: ", targetr.size())
            if args.d_label_smooth != 0:
                LDr = (1 - args.d_label_smooth)*nn.NLLLoss2d()(confr,targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr,targetf)
            else:
                LDr = nn.NLLLoss2d()(confr,targetr)
            LDr.backward()

            # Train on Fake
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff,targetf)
            LDf.backward()

            optimD = poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            ######################
            # GENERATOR TRAINING #
            ######################
            optimG.zero_grad()
            optimD.zero_grad()
            cmap = generator(img)
            cpmapsmax = nn.Softmax2d()(cmap)
            cpmaplsmax = nn.LogSoftmax()(cmap)
            conff = nn.LogSoftmax()(discriminator(cpmapsmax))


            LGce = nn.NLLLoss2d()(cpmaplsmax,mask)
            LGadv = nn.NLLLoss2d()(conff,targetr)
            LGseg = LGce + args.lam_adv *LGadv

            LGseg.backward()
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()

            print("[{}][{}] LD: {:.4f} LDfake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f}"  \
                .format(epoch,itr,(LDr + LDf).data,LDr.data,LDf.data,LGseg.data,LGce.data,LGadv.data))
                    # .format(epoch,itr,(LDr + LDf).data[0],LDr.data[0],LDf.data[0],LGseg.data[0],LGce.data[0],LGadv.data[0]))
        # best_miou = snapshot(generator,valoader,epoch,best_miou,args.snapshot_dir,args.prefix)
        best_miou, best_eiou = snapshot_segdis(generator,discriminator,valoader,epoch,best_miou,best_eiou,args.snapshot_dir,args.prefix)
Example No. 5
def train_semi_m(generator, discriminator, optimG, optimD, trainloader_l,
                 trainloader_u, valoader, args):
    best_miou = -1
    # bce_loss = BCEWithLogitsLoss2d()
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        loss_semi_value = 0
        if epoch > args.wait_semi:
            """
            Using Discriminator ouput to train generator
            """
            loss_semi_value = 0
            for batch_id, (img, mask, ohmask) in enumerate(trainloader_u):
                if args.nogpu:
                    img, mask, ohmask = Variable(img), Variable(
                        mask), Variable(ohmask)
                else:
                    img, mask, ohmask = Variable(img.cuda()), Variable(
                        mask.cuda()), Variable(ohmask.cuda())
                # generator output
                cpmap = generator(img)
                cpmapsmax = nn.Softmax2d()(cpmap)  # torch.Size([4, 2, 320, 320])

                # discriminator output
                conf = discriminator(cpmapsmax)
                # confsmax = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)
                confsmax = nn.Softmax2d()(conf)
                conflsmax = nn.LogSoftmax()(
                    conf)  # torch.Size([4, 2, 320, 320])

                N = cpmap.size()[0]
                H = cpmap.size()[2]
                W = cpmap.size()[3]
                # Adversarial Loss
                targetr = Variable(torch.ones((N, H, W)).long())
                if not args.nogpu:
                    targetr = targetr.cuda()
                LGadv = nn.NLLLoss2d()(conflsmax, targetr)
                LGadv_d = LGadv.data

                # D_output = confsmax.data.cpu().numpy().squeeze(axis=1)
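                # D_output holds the discriminator's per-pixel argmax (0 or 1), so the
                # comparison with t_semi in (0, 1) ignores exactly the pixels predicted
                # as class 0; those positions get the 255 ignore label below.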
                D_output = torch.max(confsmax, 1)[1].cpu().numpy()
                # print("D_output: ", D_output.shape)
                semi_ignore_mask = (D_output < args.t_semi)
                semi_gt = cpmapsmax.data.cpu().numpy().argmax(axis=1)
                semi_gt[semi_ignore_mask] = 255
                semi_ratio = 1.0 - float(
                    semi_ignore_mask.sum()) / semi_ignore_mask.size
                if semi_ratio == 0.0:
                    loss_semi_value += 0
                else:
                    semi_gt = torch.FloatTensor(semi_gt)
                    LG_semi = args.lam_semi * loss_calc(cpmapsmax, semi_gt)
                    loss_semi_value += LG_semi.data.cpu().numpy() / args.lam_semi
                    itr = (len(trainloader_u) + len(trainloader_l)) * (epoch - 1) + batch_id
                    optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
                    optimG.zero_grad()
                    LG_semi.backward()
                    optimG.step()

        for batch_id, (img, mask, ohmask) in enumerate(trainloader_l):
            if args.nogpu:
                img, mask, ohmask = Variable(img), Variable(mask), Variable(
                    ohmask)
            else:
                img, mask, ohmask = Variable(img.cuda()), Variable(
                    mask.cuda()), Variable(ohmask.cuda())
            itr = (len(trainloader_u) + len(trainloader_l)) * (epoch - 1) + batch_id
            # generator forward
            cpmap = generator(Variable(img.data, volatile=True))
            cpmap = nn.Softmax2d()(cpmap)
            # print("cpmap: ", cpmap.size(), " ohmask: ", ohmask.size())

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]
            # print("cpmap: ", cpmap.size())

            # Generate the Real and Fake Labels
            targetf = Variable(torch.zeros((N, H, W)).long(),
                               requires_grad=False)
            targetr = Variable(torch.ones((N, H, W)).long(),
                               requires_grad=False)
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            ##########################
            # DISCRIMINATOR TRAINING #
            ##########################
            optimD.zero_grad()

            # Train on Real
            confr = nn.LogSoftmax()(discriminator(ohmask.float()))
            # print("confr: ", confr.size())
            # print("targetr: ", targetr.size())
            if args.d_label_smooth != 0:
                LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(confr,
                                                                 targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr, targetf)
            else:
                LDr = nn.NLLLoss2d()(confr, targetr)
            LDr.backward()

            # Train on Fake
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff, targetf)
            LDf.backward()

            optimD = poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            ######################
            # GENERATOR TRAINING #
            ######################
            optimG.zero_grad()
            optimD.zero_grad()
            cmap = generator(img)
            cpmapsmax = nn.Softmax2d()(cmap)
            cpmaplsmax = nn.LogSoftmax()(cmap)
            conff = nn.LogSoftmax()(discriminator(cpmapsmax))

            LGce = nn.NLLLoss2d()(cpmaplsmax, mask)
            LGadv = nn.NLLLoss2d()(conff, targetr)
            LGseg = LGce + args.lam_adv * LGadv

            LGseg.backward()
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()

            print("[{}][{}] LD: {:.4f} LDfake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f}, LG_semi: {:.5f}"  \
                .format(epoch,itr,(LDr + LDf).data,LDr.data,LDf.data,LGseg.data,LGce.data,LGadv.data,LG_semi.data))
            # .format(epoch,itr,(LDr + LDf).data[0],LDr.data[0],LDf.data[0],LGseg.data[0],LGce.data[0],LGadv.data[0]))
        # best_miou = snapshot(generator,valoader,epoch,best_miou,args.snapshot_dir,args.prefix)
        best_miou = snapshot_segdis(generator, discriminator, valoader, epoch,
                                    best_miou, args.snapshot_dir, args.prefix)
Example No. 6
def train_semi(generator, discriminator, optimG, optimD, trainloader_l,
               trainloader_u, valoader, args):
    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        trainloader_l_iter = iter(trainloader_l)
        trainloader_u_iter = iter(trainloader_u)
        print("Epoch: {}".format(epoch))
        batch_id = 0
        # Randomly pick labeled or unlabeled data for training
        while True:
            if random.random() < 0.5:
                loader = trainloader_l_iter
                labeled = True
            else:
                loader = trainloader_u_iter
                labeled = False
            # Check if the loader has a batch available
            try:
                img, mask, ohmask = next(loader)
            except StopIteration:
                # Current loader doesn't have data
                if labeled:
                    loader = trainloader_u_iter
                    labeled = False
                else:
                    loader = trainloader_l_iter
                    labeled = True

                # Check if the new loader has data
                try:
                    img, mask, ohmask = next(loader)
                except StopIteration:
                    # Both loaders exhausted
                    break

            batch_id += 1
            if args.nogpu:
                img, mask, ohmask = Variable(img), Variable(mask), Variable(
                    ohmask)
            else:
                img, mask, ohmask = Variable(img.cuda()), Variable(
                    mask.cuda()), Variable(ohmask.cuda())
            itr = (len(trainloader_u) + len(trainloader_l)) * (epoch -
                                                               1) + batch_id
            LGseg_d = 0
            if epoch < args.wait_semi:
                ################################################
                #  Labelled data for Discriminator Training #
                ################################################
                cpmap = generator(Variable(img.data, volatile=True))
                cpmap = nn.Softmax2d()(cpmap)

                N = cpmap.size()[0]
                H = cpmap.size()[2]
                W = cpmap.size()[3]

                # Generate the Real and Fake Labels
                targetf = Variable(torch.zeros((N, H, W)).long())
                targetr = Variable(torch.ones((N, H, W)).long())
                if not args.nogpu:
                    targetf = targetf.cuda()
                    targetr = targetr.cuda()

                # Train on Real
                confr = nn.LogSoftmax()(discriminator(ohmask.float()))

                optimD.zero_grad()

                if args.d_label_smooth != 0:
                    LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(confr,
                                                                     targetr)
                    LDr += args.d_label_smooth * nn.NLLLoss2d()(confr, targetf)
                else:
                    LDr = nn.NLLLoss2d()(confr, targetr)
                LDr.backward()

                # Train on Fake
                conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
                LDf = nn.NLLLoss2d()(conff, targetf)
                LDf.backward()

                LDr_d = LDr.data
                LDf_d = LDf.data
                LD_d = LDr_d + LDf_d
                optimD = poly_lr_scheduler(optimD, args.d_lr, itr)
                optimD.step()

                #####################################
                #  labelled data Generator Training #
                #####################################
                optimG.zero_grad()
                optimD.zero_grad()
                cpmap = generator(img)
                cpmapsmax = nn.Softmax2d()(cpmap)
                cpmaplsmax = nn.LogSoftmax()(cpmap)

                conff = nn.LogSoftmax()(discriminator(cpmapsmax))

                LGce = nn.NLLLoss2d()(cpmaplsmax, mask)
                LGadv = nn.NLLLoss2d()(conff, targetr)

                LGadv_d = LGadv.data
                LGce_d = LGce.data
                LGsemi_d = 0  # No semi-supervised training

                LGadv = args.lam_adv * LGadv

                (LGce + LGadv).backward()
                optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
                optimG.step()
                LGseg_d = LGce_d + args.lam_adv * LGadv_d + args.lam_semi * LGsemi_d
            else:
                if labeled:
                    ################################################
                    #  Labelled data for Discriminator Training #
                    ################################################
                    cpmap = generator(Variable(img.data, volatile=True))
                    cpmap = nn.Softmax2d()(cpmap)

                    N = cpmap.size()[0]
                    H = cpmap.size()[2]
                    W = cpmap.size()[3]

                    # Generate the Real and Fake Labels
                    targetf = Variable(torch.zeros((N, H, W)).long())
                    targetr = Variable(torch.ones((N, H, W)).long())
                    if not args.nogpu:
                        targetf = targetf.cuda()
                        targetr = targetr.cuda()

                    # Train on Real
                    confr = nn.LogSoftmax()(discriminator(ohmask.float()))

                    optimD.zero_grad()

                    if args.d_label_smooth != 0:
                        LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(
                            confr, targetr)
                        LDr += args.d_label_smooth * nn.NLLLoss2d()(confr,
                                                                    targetf)
                    else:
                        LDr = nn.NLLLoss2d()(confr, targetr)
                    LDr.backward()

                    # Train on Fake
                    conff = nn.LogSoftmax()(discriminator(Variable(
                        cpmap.data)))
                    LDf = nn.NLLLoss2d()(conff, targetf)
                    LDf.backward()

                    LDr_d = LDr.data
                    LDf_d = LDf.data
                    LD_d = LDr_d + LDf_d
                    optimD = poly_lr_scheduler(optimD, args.d_lr, itr)
                    optimD.step()

                    #####################################
                    #  labelled data Generator Training #
                    #####################################
                    optimG.zero_grad()
                    optimD.zero_grad()
                    cpmap = generator(img)
                    cpmapsmax = nn.Softmax2d()(cpmap)
                    cpmaplsmax = nn.LogSoftmax()(cpmap)

                    conff = nn.LogSoftmax()(discriminator(cpmapsmax))

                    LGce = nn.NLLLoss2d()(cpmaplsmax, mask)
                    LGadv = nn.NLLLoss2d()(conff, targetr)

                    LGadv_d = LGadv.data
                    LGce_d = LGce.data
                    LGsemi_d = 0  # No semi-supervised training

                    LGadv = args.lam_adv * LGadv

                    (LGce + LGadv).backward()
                    optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
                    optimG.step()
                    LGseg_d = LGce_d + args.lam_adv * LGadv_d + args.lam_semi * LGsemi_d

                else:
                    #####################################
                    # Use unlabelled data to get L_semi #
                    #####################################
                    # No discriminator training
                    LD_d = 0
                    LDr_d = 0
                    LDf_d = 0
                    # Init all loss to 0 for logging ease
                    LGsemi_d = 0
                    LGce_d = 0
                    LGadv_d = 0
                    optimG.zero_grad()
                    # if epoch > args.wait_semi:
                    cpmap = generator(img)
                    cpmapsmax = nn.Softmax2d()(cpmap)

                    conf = discriminator(cpmapsmax)
                    confsmax = nn.Softmax2d()(conf)
                    conflsmax = nn.LogSoftmax()(conf)

                    N = cpmap.size()[0]
                    H = cpmap.size()[2]
                    W = cpmap.size()[3]

                    # Adversarial Loss
                    targetr = Variable(torch.ones((N, H, W)).long())
                    if not args.nogpu:
                        targetr = targetr.cuda()
                    LGadv = nn.NLLLoss2d()(conflsmax, targetr)
                    LGadv_d = LGadv.data
                    # Semi-Supervised Loss

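                    # Same self-training trick as above: build a one-hot index mask over
                    # pixels whose discriminator confidence exceeds t_semi and use the
                    # generator's argmax there as pseudo-labels.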
                    hardpred = torch.max(cpmapsmax, 1)[1].squeeze(1)

                    idx = np.zeros(cpmap.data.cpu().numpy().shape,
                                   dtype=np.uint8)
                    idx = idx.transpose(0, 2, 3, 1)

                    confnp = confsmax[:, 1, ...].data.cpu().numpy()
                    hardprednp = hardpred.data.cpu().numpy()
                    idx[confnp > args.t_semi] = np.identity(
                        2, dtype=idx.dtype)[hardprednp[confnp > args.t_semi]]

                    LG = args.lam_adv * LGadv
                    if np.count_nonzero(idx) != 0:
                        cpmaplsmax = nn.LogSoftmax()(cpmap)
                        idx = Variable(
                            torch.from_numpy(idx.transpose(0, 3, 1,
                                                           2)).byte().cuda())
                        # LGsemi_arr = cpmaplsmax.masked_select(idx)
                        LGsemi_arr = cpmaplsmax.masked_select(idx.bool())
                        LGsemi = -1 * LGsemi_arr.mean()
                        LGsemi_d = LGsemi.data
                        LGsemi = args.lam_semi * LGsemi
                        LG += LGsemi

                    LG.backward()
                    optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
                    optimG.step()
                    # Manually free all variables. Look into details of how variables are freed
                    del idx
                    del confnp
                    del confsmax
                    del hardpred
                    del hardprednp
                    del cpmapsmax
                    del cpmap
                    LGseg_d = LGce_d + args.lam_adv * LGadv_d + args.lam_semi * LGsemi_d

            # Manually free memory! Later, really understand how computation graphs free variables
            # LD: discriminator loss, LD_fake: fake D loss, LD_real: real D loss, LG_ce: generator CE loss
            # LG_adv: adversarial loss, LG_semi: semi-supervised loss
            print("[{}][{}] LD: {:.4f} LD_fake: {:.4f} LD_real: {:.4f} LGseg: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f} LG_semi: {:.4f}"\
                    .format(epoch,itr,LD_d,LDf_d,LDr_d,LGseg_d,LGce_d,LGadv_d,LGsemi_d))
        # best_miou = snapshot(generator,valoader,epoch,best_miou,args.snapshot_dir,args.prefix)
        best_miou = snapshot_segdis(generator, discriminator, valoader, epoch,
                                    best_miou, args.snapshot_dir, args.prefix)
Example No. 7
def train_label_correction(generator, optimG, trainloader, valoader, args):
    best_miou = -1
    best_eiou = -1
    # FILE_INFO = file_info()

    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        for batch_id, (img_l, mask_l, ohmask_l, _,
                       imgs_org) in enumerate(trainloader):

            itr = len(trainloader) * (epoch - 1) + batch_id
            if args.nogpu:
                img_l, mask_l, ohmask_l = Variable(img_l), Variable(
                    mask_l), Variable(ohmask_l)
            else:
                img_l, mask_l, ohmask_l = Variable(img_l.cuda()), Variable(
                    mask_l.cuda()), Variable(ohmask_l.cuda())
            imgs_array = imgs_org.numpy()

            LGcorr = torch.tensor(0)
            loss_corr_value = 0
            # begin label correction

            ################################################
            #  train labelled data                         #
            ################################################
            cpmap = generator(img_l)
            cpmapsmax = nn.Softmax2d()(cpmap)
            cpmaplsmax = nn.LogSoftmax()(cpmap)
            LGce = nn.NLLLoss()(cpmaplsmax, mask_l)
            LGce_d = LGce.data
            optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.zero_grad()
            LGce.backward()
            optimG.step()

            ################################################
            #  label correction                            #
            ################################################
            if epoch > args.wait_semi:
                # Check if the loader has a batch available

                # semi unlabelled training
                cpmap = generator(img_l)
                cpmapsmax = nn.Softmax2d()(cpmap)
                cpmaplsmax = nn.LogSoftmax()(cpmap)

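            # Label correction: refine the current soft predictions with a CRF over the
            # ROIs returned by get_roi_crf, then use the corrected argmax as a new mask.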
                soft_preds_np = cpmapsmax.data.cpu().numpy()
                soft_mask = np.zeros(
                    (soft_preds_np.shape[0], soft_preds_np.shape[2],
                     soft_preds_np.shape[3]))
                for i in range(soft_preds_np.shape[0]):
                    soft_pred = soft_preds_np[i]
                    img_array = imgs_array[i]
                    img_array = cv2.resize(img_array, (320, 320),
                                           interpolation=cv2.INTER_AREA)
                    hard_pred = np.argmax(soft_pred, axis=0).astype(np.uint8)
                    rois_crf, rois = get_roi_crf(img_array, soft_pred,
                                                 hard_pred)
                    for roi_crf, roi in zip(rois_crf, rois):
                        # print(roi)
                        xmin, ymin, xmax, ymax = roi[0], roi[1], roi[2], roi[3]
                        soft_pred[:, ymin:ymax, xmin:xmax] = roi_crf
                    crf_result = np.argmax(soft_pred, axis=0).astype(np.uint8)
                    soft_mask[i] = crf_result
                mask_c = Variable(torch.tensor(soft_mask)).type(
                    torch.LongTensor).cuda()
                LGcorr = nn.NLLLoss()(cpmaplsmax, mask_c)
                optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
                optimG.zero_grad()
                LGcorr.backward()
                optimG.step()

            LGcorr_d = LGcorr.data
            LGseg_d = LGce_d + args.lam_adv * LGcorr_d
            print("[{}][{}] LGseg_d: {:.4f} LGce_d: {:.4f} LG_corr: {:.4f}"\
                    .format(epoch,itr,LGseg_d,LGce_d,LGcorr_d))

        # best_miou = snapshot(generator,valoader,epoch,best_miou,args.snapshot_dir,args.prefix)
        best_miou, best_eiou = snapshote(generator, valoader, epoch, best_miou,
                                         best_eiou, args.snapshot_dir,
                                         args.prefix)
Example No. 8
def train_semi(generator, discriminator, optimG, optimD, trainloader_l,
               trainloader_u, valoader, args):

    best_miou = -1
    best_eiou = -1
    FILE_INFO = file_info()
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()

        trainloader_l_iter = iter(trainloader_l)
        trainloader_u_iter = iter(trainloader_u)
        print("Epoch: {}".format(epoch))
        batch_id = 0
        while True:
            batch_id += 1
            itr = (len(trainloader_u) + len(trainloader_l)) * (epoch -
                                                               1) + batch_id
            optimG.zero_grad()
            optimD.zero_grad()
            optimD = poly_lr_scheduler(optimD, args.d_lr, itr)
            optimG = poly_lr_scheduler(optimG, args.g_lr, itr)

            loader_l = trainloader_l_iter
            try:
                img_l, mask_l, ohmask_l, _ = next(loader_l)
            except StopIteration:
                break

            if args.nogpu:
                img_l, mask_l, ohmask_l = Variable(img_l), Variable(
                    mask_l), Variable(ohmask_l)
            else:
                img_l, mask_l, ohmask_l = Variable(img_l.cuda()), Variable(
                    mask_l.cuda()), Variable(ohmask_l.cuda())

            LGsemi_d = 0
            loss_semi_value = 0
            if epoch > args.wait_semi:
                # Check if the loader has a batch available

                # semi unlabelled training
                cpmap = generator(img_l)
                cpmapsmax = nn.Softmax2d()(cpmap)
                cpmaplsmax = nn.LogSoftmax()(cpmap)
                # confsmax = nn.Softmax2d()(conf)
                conf = discriminator(cpmapsmax)
                conf_softmax = nn.Softmax2d()(conf)
                conf_softmax_remain = conf_softmax.detach().squeeze(1)
                conf_soft = conf_softmax_remain.data.cpu().numpy()

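                # Treat the discriminator's per-pixel argmax as a pseudo ground truth;
                # semi_ratio is the fraction of pixels it assigns to class 0.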
                semi_gt = np.argmax(conf_soft, axis=1).astype(np.uint8)
                semi_ratio = 1.0 - float(semi_gt.sum()) / semi_gt.size
                # print('semi ratio: {:.4f}'.format(semi_ratio))
                if semi_ratio == 0.0:
                    loss_semi_value += 0
                else:
                    semi_gt = torch.LongTensor(semi_gt).cuda()
                    LGsemi = args.lam_semi * nn.NLLLoss()(cpmaplsmax, semi_gt)
                    LGsemi_d += LGsemi.data.cpu().numpy() / args.lam_semi
                    # LGsemi += LGadv_d
                    LGsemi.backward()

            ################################################
            #  train labelled data                         #
            ################################################
            cpmap = generator(img_l)
            cpmapsmax = nn.Softmax2d()(cpmap)
            cpmaplsmax = nn.LogSoftmax()(cpmap)
            conff = nn.LogSoftmax()(discriminator(cpmapsmax))
            LGce = nn.NLLLoss()(cpmaplsmax, mask_l)
            # LGce = loss_calc(cpmaplsmax, mask_l)

            # LGadv = nn.BCELoss()(conff,mask_l)
            # LGadv = loss_calc(conff, mask_l)
            LGadv = nn.NLLLoss()(conff, mask_l)
            LGadv_d = LGadv.data
            LGce_d = LGce.data
            # print("{}: {}".format(str(FILE_INFO), conff.size()))

            # LGsemi_d = 0 # No semi-supervised training
            LGadv = args.lam_adv * LGadv
            (LGce + LGadv).backward()
            # optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()
            optimD.step()

            LGseg_d = LGce_d + LGadv_d + LGsemi_d
            # LGseg_c = LGce_c + LGadv_c + LGsemi_c

            # training_log = [epoch,itr,LD_c,LDr_c,LDf_c,LGseg_c,LGce_c,LGadv_c,LGsemi_c]
            print("[{}][{}] LGseg_d: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f} LG_semi: {:.4f}"\
                    .format(epoch,itr,LGseg_d,LGce_d,LGadv_d,LGsemi_d))

        # best_miou = snapshot(generator,valoader,epoch,best_miou,args.snapshot_dir,args.prefix)
        best_miou, best_eiou = snapshot_segdis(generator, discriminator,
                                               valoader, epoch, best_miou,
                                               best_eiou, args.snapshot_dir,
                                               args.prefix)
Example No. 9
def train_semi_b(generator, discriminator, optimG, optimD, trainloader_l,
                 trainloader_u, valoader, args):

    best_miou = -1
    best_eiou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()

        trainloader_l_iter = iter(trainloader_l)
        trainloader_u_iter = iter(trainloader_u)
        print("Epoch: {}".format(epoch))
        batch_id = 0
        while True:
            batch_id += 1
            itr = (len(trainloader_u) + len(trainloader_l)) * (epoch -
                                                               1) + batch_id
            optimG.zero_grad()
            optimD.zero_grad()
            optimD = poly_lr_scheduler(optimD, args.d_lr, itr)
            optimG = poly_lr_scheduler(optimG, args.g_lr, itr)

            LGsemi_d = 0
            loss_semi_value = 0
            if epoch > args.wait_semi:
                if random.random() < 0.5:
                    loader_l = trainloader_l_iter
                    loader_u = trainloader_u_iter
                else:
                    loader_l = trainloader_u_iter
                    loader_u = trainloader_l_iter
                # Check if the loader has a batch available
                try:
                    img_u, mask_u, ohmask_u, _ = next(loader_u)
                except StopIteration:
                    trainloader_u_iter = iter(trainloader_u)
                    loader_u = trainloader_u_iter
                    img_u, mask_u, ohmask_u, _ = next(loader_u)

                if args.nogpu:
                    img_u, mask_u, ohmask_u = Variable(img_u), Variable(
                        mask_u), Variable(ohmask_u)
                else:
                    img_u, mask_u, ohmask_u = Variable(img_u.cuda()), Variable(
                        mask_u.cuda()), Variable(ohmask_u.cuda())

                # semi unlabelled training
                cpmap = generator(img_u)
                cpmapsmax = nn.Softmax2d()(cpmap)
                # confsmax = nn.Softmax2d()(conf)
                conf = discriminator(cpmapsmax)
                confsigmoid = nn.Sigmoid()(conf)
                confsigmoid_remain = confsigmoid.detach().squeeze(1)
                # conflsmax = nn.LogSoftmax()(conf)

                N = cpmap.size()[0]
                H = cpmap.size()[2]
                W = cpmap.size()[3]

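                # Pseudo ground truth from the thresholded sigmoid confidence: pixels at
                # or above t_semi become 1; semi_ratio is the fraction left at 0.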
                semi_gt = np.int8(
                    (confsigmoid_remain.cpu().numpy() >= args.t_semi))
                # semi_gt = cpmap.data.cpu().numpy().argmax(axis=1)
                # semi_gt[semi_ignore_mask] = 1
                semi_ratio = 1.0 - float(semi_gt.sum()) / semi_gt.size
                # print('semi ratio: {:.4f}'.format(semi_ratio))
                if semi_ratio == 0.0:
                    loss_semi_value += 0
                else:
                    semi_gt = torch.FloatTensor(semi_gt)
                    LGsemi = args.lam_semi * loss_calc(cpmapsmax, semi_gt)
                    LGsemi_d += LGsemi.data.cpu().numpy() / args.lam_semi
                    # LGsemi += LGadv_d
                    LGsemi.backward()

            ################################################
            #  train labelled data                         #
            ################################################
            loader_l = trainloader_l_iter
            loader_u = trainloader_u_iter
            try:
                if random.random() < 0.5:
                    img_l, mask_l, ohmask_l, _ = next(loader_l)
                else:
                    img_l, mask_l, ohmask_l, _ = next(loader_u)
            except:
                break
            if args.nogpu:
                img_l, mask_l, ohmask_l = Variable(img_l), Variable(
                    mask_l), Variable(ohmask_l)
            else:
                img_l, mask_l, ohmask_l = Variable(img_l.cuda()), Variable(
                    mask_l.cuda()), Variable(ohmask_l.cuda())

            #####################################
            #  labelled data Generator Training #
            #####################################
            # optimG.zero_grad()
            # optimD.zero_grad()
            cpmap = generator(img_l)
            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]
            # Generate the Real and Fake Labels
            targetf = Variable(torch.zeros((N, H, W)))  # .long()
            targetr = Variable(torch.ones((N, H, W)))  # .long()
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            cpmapsmax = nn.Softmax2d()(cpmap)
            cpmaplsmax = nn.LogSoftmax()(cpmap)
            conff = nn.Sigmoid()(discriminator(cpmapsmax))
            print("cpmaplsmax: ", cpmaplsmax.size())
            print("mask_l: ", mask_l.size())
            print("conff: ", conff.size())
            # LGce = loss_calc(cpmaplsmax, mask_l)
            LGce = nn.NLLLoss()(cpmaplsmax, mask_l)

            # LGadv = nn.BCELoss()(conff,mask_l)
            LGadv = loss_calc(cpmaplsmax, mask_l)
            LGadv_d = LGadv.data
            # LGadv_c = LGadv.cpu().numpy()
            LGce_d = LGce.data
            # LGce_c = LGce.cpu().numpy()

            # LGsemi_d = 0 # No semi-supervised training
            LGadv = args.lam_adv * LGadv
            (LGce + LGadv).backward()
            # optimG = poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()
            optimD.step()

            LGseg_d = LGce_d + LGadv_d + LGsemi_d
            # LGseg_c = LGce_c + LGadv_c + LGsemi_c

            # training_log = [epoch,itr,LD_c,LDr_c,LDf_c,LGseg_c,LGce_c,LGadv_c,LGsemi_c]
            print("[{}][{}] LGseg_d: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f} LG_semi: {:.4f}"\
                    .format(epoch,itr,LGseg_d,LGce_d,LGadv_d,LGsemi_d))

        # best_miou = snapshot(generator,valoader,epoch,best_miou,args.snapshot_dir,args.prefix)
        best_miou, best_eiou = snapshot_segdis(generator, discriminator,
                                               valoader, epoch, best_miou,
                                               best_eiou, args.snapshot_dir,
                                               args.prefix)
def train_semi(args):
    # TODO: Make it more generic to include for other splits
    args.batch_size = args.batch_size * 2

    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop((321, 321))]
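    # 321x321 random crops are the standard DeepLab-v2 training resolution for PASCAL VOC.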

    trainset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), label_transform=Compose(labtr), \
        co_transform=Compose(cotr))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), \
        label_transform = Compose(labtr),co_transform=Compose(cotr),train_phase=False)
    valoader = DataLoader(valset, batch_size=1)
    #############
    # GENERATOR #
    #############
    generator = deeplabv2.ResDeeplab()
    optimG = optim.SGD(filter(lambda p: p.requires_grad, \
        generator.parameters()),lr=args.g_lr,momentum=0.9,\
        weight_decay=0.0001,nesterov=True)

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    discriminator = Dis(in_channels=21)
    if args.d_optim == 'adam':
        optimD = optim.Adam(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr = args.d_lr)
    else:
        optimD = optim.SGD(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr=args.d_lr,weight_decay=0.0001,momentum=0.5,nesterov=True)

    if not args.nogpu:
        discriminator = nn.DataParallel(discriminator).cuda()

    ############
    # TRAINING #
    ############
    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        for batch_id, (img, mask, ohmask) in enumerate(trainloader):
            if args.nogpu:
                img, mask, ohmask = Variable(img), Variable(mask), Variable(
                    ohmask)
            else:
                img, mask, ohmask = Variable(img.cuda()), Variable(
                    mask.cuda()), Variable(ohmask.cuda())
            itr = len(trainloader) * (epoch - 1) + batch_id
            ## TODO: Extend random interleaving for split of any size
            mid = args.batch_size // 2
            img1, mask1, ohmask1 = img[0:mid, ...], mask[0:mid, ...], ohmask[0:mid, ...]
            img2, mask2, ohmask2 = img[mid:, ...], mask[mid:, ...], ohmask[mid:, ...]

            # Random Interleaving
            if random.random() < 0.5:
                imgl, maskl, ohmaskl = img1, mask1, ohmask1
                imgu, masku, ohmasku = img2, mask2, ohmask2
            else:
                imgu, masku, ohmasku = img1, mask1, ohmask1
                imgl, maskl, ohmaskl = img2, mask2, ohmask2

            ################################################
            #  Labelled data for Discriminator Training #
            ################################################
            cpmap = generator(Variable(imgl.data, volatile=True))
            cpmap = nn.Softmax2d()(cpmap)

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]

            # Generate the Real and Fake Labels
            targetf = Variable(torch.zeros((N, H, W)).long())
            targetr = Variable(torch.ones((N, H, W)).long())
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            # Train on Real
            confr = nn.LogSoftmax()(discriminator(ohmaskl.float()))
            optimD.zero_grad()
            if args.d_label_smooth != 0:
                LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(confr,
                                                                 targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr, targetf)
            else:
                LDr = nn.NLLLoss2d()(confr, targetr)
            LDr.backward()

            # Train on Fake
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff, targetf)
            LDf.backward()

            poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            ###########################################
            #  labelled data Generator Training       #
            ###########################################
            optimG.zero_grad()

            cpmap = generator(imgl)
            cpmapsmax = nn.Softmax2d()(cpmap)
            cpmaplsmax = nn.LogSoftmax()(cpmap)

            conff = nn.LogSoftmax()(discriminator(cpmapsmax))

            LGce = nn.NLLLoss2d()(cpmaplsmax, maskl)
            LGadv = nn.NLLLoss2d()(conff, targetr)

            LGadv_d = LGadv.data[0]
            LGce_d = LGce.data[0]

            LGadv = args.lam_adv * LGadv

            (LGce + LGadv).backward()
            #####################################
            # Use unlabelled data to get L_semi #
            #####################################
            LGsemi_d = 0
            if epoch > args.wait_semi:

                cpmap = generator(imgu)
                softpred = nn.Softmax2d()(cpmap)
                hardpred = torch.max(softpred, 1)[1].squeeze(1)
                conf = nn.Softmax2d()(discriminator(
                    Variable(softpred.data, volatile=True)))

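                # Build a one-hot mask over the 21 VOC classes selecting pixels whose
                # discriminator confidence exceeds t_semi; the generator's argmax at
                # those pixels serves as the pseudo-label for the semi-supervised loss.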
                idx = np.zeros(cpmap.data.cpu().numpy().shape, dtype=np.uint8)
                idx = idx.transpose(0, 2, 3, 1)

                confnp = conf[:, 1, ...].data.cpu().numpy()
                hardprednp = hardpred.data.cpu().numpy()
                idx[confnp > args.t_semi] = np.identity(
                    21, dtype=idx.dtype)[hardprednp[confnp > args.t_semi]]

                if np.count_nonzero(idx) != 0:
                    cpmaplsmax = nn.LogSoftmax()(cpmap)
                    idx = Variable(torch.from_numpy(idx.transpose(0, 3, 1, 2)).byte().cuda())
                    LGsemi_arr = cpmaplsmax.masked_select(idx)
                    LGsemi = -1 * LGsemi_arr.mean()
                    LGsemi_d = LGsemi.data[0]
                    LGsemi = args.lam_semi * LGsemi
                    LGsemi.backward()
                else:
                    LGsemi_d = 0

                del idx
                del conf
                del confnp
                del hardpred
                del softpred
                del hardprednp
                del cpmap
            LGseg_d = LGce_d + LGadv_d + LGsemi_d
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()

            # Manually free memory! Later, really understand how computation graphs free variables

            print("[{}][{}] LD: {:.4f} LD_fake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f} LG_semi: {:.4f}"\
                    .format(epoch,itr,(LDr + LDf).data[0],LDr.data[0],LDf.data[0],LGseg_d,LGce_d,LGadv_d,LGsemi_d))

        best_miou = snapshot(generator, valoader, epoch, best_miou,
                             args.snapshot_dir, args.prefix)
def train_adv(args):
    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop((321, 321))]

    trainset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), label_transform=Compose(labtr), \
        co_transform=Compose(cotr))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), \
        label_transform = Compose(labtr),co_transform=Compose(cotr),train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    #############
    # GENERATOR #
    #############
    generator = deeplabv2.ResDeeplab()
    optimG = optim.SGD(filter(lambda p: p.requires_grad, \
        generator.parameters()),lr=args.g_lr,momentum=0.9,\
        weight_decay=0.0001,nesterov=True)

    if not args.nogpu:
        generator = nn.DataParallel(generator).cuda()

    #################
    # DISCRIMINATOR #
    #################
    discriminator = Dis(in_channels=21)
    if args.d_optim == 'adam':
        optimD = optim.Adam(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr = args.d_lr)
    else:
        optimD = optim.SGD(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr=args.d_lr,weight_decay=0.0001,momentum=0.5,nesterov=True)

    if not args.nogpu:
        discriminator = nn.DataParallel(discriminator).cuda()

    #############
    # TRAINING  #
    #############
    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        generator.train()
        for batch_id, (img, mask, ohmask) in enumerate(trainloader):
            if args.nogpu:
                img, mask, ohmask = Variable(img), Variable(mask), Variable(
                    ohmask)
            else:
                img, mask, ohmask = Variable(img.cuda()), Variable(
                    mask.cuda()), Variable(ohmask.cuda())
            itr = len(trainloader) * (epoch - 1) + batch_id
            cpmap = generator(Variable(img.data, volatile=True))
            cpmap = nn.Softmax2d()(cpmap)

            N = cpmap.size()[0]
            H = cpmap.size()[2]
            W = cpmap.size()[3]

            # Generate the Real and Fake Labels
            targetf = Variable(torch.zeros((N, H, W)).long(),
                               requires_grad=False)
            targetr = Variable(torch.ones((N, H, W)).long(),
                               requires_grad=False)
            if not args.nogpu:
                targetf = targetf.cuda()
                targetr = targetr.cuda()

            ##########################
            # DISCRIMINATOR TRAINING #
            ##########################
            optimD.zero_grad()

            # Train on Real
            confr = nn.LogSoftmax()(discriminator(ohmask.float()))
            if args.d_label_smooth != 0:
                LDr = (1 - args.d_label_smooth) * nn.NLLLoss2d()(confr,
                                                                 targetr)
                LDr += args.d_label_smooth * nn.NLLLoss2d()(confr, targetf)
            else:
                LDr = nn.NLLLoss2d()(confr, targetr)
            LDr.backward()

            # Train on Fake
            conff = nn.LogSoftmax()(discriminator(Variable(cpmap.data)))
            LDf = nn.NLLLoss2d()(conff, targetf)
            LDf.backward()

            poly_lr_scheduler(optimD, args.d_lr, itr)
            optimD.step()

            ######################
            # GENERATOR TRAINING #
            ######################
            optimG.zero_grad()

            cmap = generator(img)
            cpmapsmax = nn.Softmax2d()(cmap)
            cpmaplsmax = nn.LogSoftmax()(cmap)
            conff = nn.LogSoftmax()(discriminator(cpmapsmax))

            LGce = nn.NLLLoss2d()(cpmaplsmax, mask)
            LGadv = nn.NLLLoss2d()(conff, targetr)
            LGseg = LGce + args.lam_adv * LGadv

            LGseg.backward()
            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.step()

            print("[{}][{}] LD: {:.4f} LDfake: {:.4f} LD_real: {:.4f} LG: {:.4f} LG_ce: {:.4f} LG_adv: {:.4f}"  \
                    .format(epoch,itr,(LDr + LDf).data[0],LDr.data[0],LDf.data[0],LGseg.data[0],LGce.data[0],LGadv.data[0]))
        best_miou = snapshot(generator, valoader, epoch, best_miou,
                             args.snapshot_dir, args.prefix)
def train_base(args):

    #######################
    # Training Dataloader #
    #######################

    if args.no_norm:
        imgtr = [ToTensor()]
    else:
        imgtr = [ToTensor(), NormalizeOwn()]

    labtr = [IgnoreLabelClass(), ToTensorLabel()]
    cotr = [RandomSizedCrop((321, 321))]

    trainset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), label_transform=Compose(labtr), \
        co_transform=Compose(cotr))
    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=2,
                             drop_last=True)

    #########################
    # Validation Dataloader #
    #########################
    if args.val_orig:
        if args.no_norm:
            imgtr = [ZeroPadding(), ToTensor()]
        else:
            imgtr = [ZeroPadding(), ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = []
    else:
        if args.no_norm:
            imgtr = [ToTensor()]
        else:
            imgtr = [ToTensor(), NormalizeOwn()]
        labtr = [IgnoreLabelClass(), ToTensorLabel()]
        cotr = [RandomSizedCrop((321, 321))]

    valset = PascalVOC(home_dir,args.dataset_dir,img_transform=Compose(imgtr), \
        label_transform = Compose(labtr),co_transform=Compose(cotr),train_phase=False)
    valoader = DataLoader(valset, batch_size=1)

    model = deeplabv2.ResDeeplab()
    init_weights(model, args.init_net)

    optimG = optim.SGD(filter(lambda p: p.requires_grad, \
        model.parameters()),lr=args.g_lr,momentum=0.9,\
        weight_decay=0.0001,nesterov=True)

    if not args.nogpu:
        model = nn.DataParallel(model).cuda()

    best_miou = -1
    for epoch in range(args.start_epoch, args.max_epoch + 1):
        model.train()
        for batch_id, (img, mask, _) in enumerate(trainloader):

            if args.nogpu:
                img, mask = Variable(img), Variable(mask)
            else:
                img, mask = Variable(img.cuda()), Variable(mask.cuda())

            itr = len(trainloader) * (epoch - 1) + batch_id
            cprob = model(img)
            cprob = nn.LogSoftmax()(cprob)

            Lseg = nn.NLLLoss2d()(cprob, mask)

            poly_lr_scheduler(optimG, args.g_lr, itr)
            optimG.zero_grad()

            Lseg.backward()
            optimG.step()

            print("[{}][{}]Loss: {:0.4f}".format(epoch, itr, Lseg.data[0]))

        best_miou = snapshot(model, valoader, epoch, best_miou,
                             args.snapshot_dir, args.prefix)