def _loss_to_float(loss):
    """Return the single value held by a one-element loss tensor as a float.

    Works for both shape-``(1,)`` results (PyTorch < 0.4) and 0-dim results
    (PyTorch >= 0.4), where indexing the numpy array with ``[0]`` would fail.
    """
    return float(loss.data.cpu().numpy().reshape(-1)[0])


def _next_batch(loader, loader_iter):
    """Fetch the next batch, restarting ``loader`` once it is exhausted.

    Args:
        loader: the iterable to restart from when ``loader_iter`` runs out.
        loader_iter: an ``enumerate(loader)`` iterator.

    Returns:
        ``(batch, loader_iter)`` where ``loader_iter`` is the (possibly new)
        iterator to use for the following call.

    Raises:
        StopIteration: if a freshly restarted loader yields nothing.
    """
    try:
        _, batch = next(loader_iter)
    except StopIteration:
        loader_iter = enumerate(loader)
        _, batch = next(loader_iter)
    return batch, loader_iter


def main():
    """Adversarially train the segmentation network with a concat discriminator.

    Reads its configuration from the module-level ``args``. The generator
    (``Res_Deeplab``) is optimised with SGD on the segmentation loss plus a
    weighted adversarial loss; the discriminator (``Discriminator_concat``)
    is optimised with Adam on softmax-prediction / one-hot-label / image
    concatenations. Generator snapshots are written to ``args.snapshot_dir``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D = Discriminator_concat(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    if not args.partial_data:
        # full dataset
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, shuffle=True,
                                      num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, shuffle=True,
                                         num_workers=5, pin_memory=True)
    else:
        # sample partial data: train on a random subset of the dataset
        partial_size = int(args.partial_data * train_dataset_size)
        if partial_size <= 0:
            raise ValueError(
                'partial_data={} selects no samples out of {}'.format(
                    args.partial_data, train_dataset_size))

        train_ids = torch.randperm(train_dataset_size).tolist()
        subset_ids = train_ids[:partial_size]

        trainloader = data.DataLoader(
            train_dataset, batch_size=args.batch_size,
            sampler=data.sampler.SubsetRandomSampler(subset_ids),
            num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(
            train_gt_dataset, batch_size=args.batch_size,
            sampler=data.sampler.SubsetRandomSampler(subset_ids),
            num_workers=5, pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = torch.nn.BCELoss()

    # NOTE(review): size is passed as (w, h) although input_size is (h, w);
    # this only matters for non-square inputs -- confirm the intended order.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # file-name stem shared by every snapshot written below
    snapshot_stem = 'VOC_' + osp.basename(osp.abspath(__file__)).split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        # never updated in this script; kept so the log line keeps its format
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source
            batch, trainloader_iter = _next_batch(trainloader, trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels)

            D_gt = Variable(one_hot(labels)).cuda()

            D_out = model_D(torch.cat([F.softmax(pred, dim=1), D_gt, F.sigmoid(images)], 1))

            # the generator is rewarded when D scores its prediction as "gt"
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, D_out))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred*5

            # proper normalization
            loss = loss/args.iter_size
            loss.backward()
            loss_seg_value += _loss_to_float(loss_seg)/args.iter_size
            loss_adv_pred_value += _loss_to_float(loss_adv_pred)/args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred (detached so no gradient reaches the generator)
            pred = pred.detach()

            D_out = model_D(torch.cat([F.softmax(pred, dim=1), D_gt, F.sigmoid(images)], 1))
            loss_D = bce_loss(D_out, make_D_label(pred_label, D_out))
            loss_D = loss_D/args.iter_size/2
            loss_D.backward()
            loss_D_value += _loss_to_float(loss_D)

            # train with gt
            # get gt labels
            batch, trainloader_gt_iter = _next_batch(trainloader_gt, trainloader_gt_iter)

            img_gt, labels_gt, _, _ = batch
            img_gt = Variable(img_gt).cuda()
            pred_gt = interp(model(img_gt))
            pred_gt = pred_gt.detach()
            D_gt_v = Variable(one_hot(labels_gt)).cuda()

            # note the channel order: one-hot label first, prediction second
            D_out = model_D(torch.cat([D_gt_v, F.softmax(pred_gt, dim=1), F.sigmoid(img_gt)], 1))
            loss_D = bce_loss(D_out, make_D_label(gt_label, D_out))
            loss_D = loss_D/args.iter_size/2
            loss_D.backward()
            loss_D_value += _loss_to_float(loss_D)

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value, loss_semi_adv_value))

        # NOTE: only the generator is snapshotted; the discriminator save
        # calls are disabled in this script.
        if i_iter >= args.num_steps-1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, snapshot_stem+'_'+str(args.num_steps)+'.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, snapshot_stem+'_'+str(i_iter)+'.pth'))

    # `start` is not defined in this function; presumably a module-level
    # timestamp taken at launch -- verify against the rest of the file.
    end = timeit.default_timer()
    print(end-start, 'seconds')
Exemplo n.º 2
0
def _loss_to_float(loss):
    """Return the single value held by a one-element loss tensor as a float.

    Works for both shape-``(1,)`` results (PyTorch < 0.4) and 0-dim results
    (PyTorch >= 0.4), where indexing the numpy array with ``[0]`` would fail.
    """
    return float(loss.data.cpu().numpy().reshape(-1)[0])


def _next_batch(loader, loader_iter):
    """Fetch the next batch, restarting ``loader`` once it is exhausted.

    Args:
        loader: the iterable to restart from when ``loader_iter`` runs out.
        loader_iter: an ``enumerate(loader)`` iterator.

    Returns:
        ``(batch, loader_iter)`` where ``loader_iter`` is the (possibly new)
        iterator to use for the following call.

    Raises:
        StopIteration: if a freshly restarted loader yields nothing.
    """
    try:
        _, batch = next(loader_iter)
    except StopIteration:
        loader_iter = enumerate(loader)
        _, batch = next(loader_iter)
    return batch, loader_iter


def main():
    """Adversarial (optionally semi-supervised) training of the segmenter.

    Reads its configuration from the module-level ``args``. The generator
    (``Res_Deeplab``) is optimised with SGD; the discriminator
    (``FCDiscriminator``) with Adam. When ``args.partial_data`` is set, the
    dataset is split into a labelled subset and a remainder; the remainder
    feeds the semi-supervised losses once ``i_iter`` reaches
    ``args.semi_start_adv`` / ``args.semi_start``. Snapshots of both networks
    are written to ``args.snapshot_dir``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))

    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))

    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    # Only the partial-data split produces a remainder loader.
    trainloader_remain = None
    trainloader_remain_iter = None

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, shuffle=True,
                                      num_workers=5, pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, shuffle=True,
                                         num_workers=5, pin_memory=True)

        if args.lambda_semi > 0 or args.lambda_semi_adv > 0:
            print('warning: partial_data is None, so there is no remainder split; '
                  'semi-supervised losses are skipped')
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE(security): unpickling runs arbitrary code; only load id
            # files that come from a trusted source.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # persist the split so a later run can reuse it via args.partial_id
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, sampler=train_sampler,
                                      num_workers=3, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size, sampler=train_remain_sampler,
                                             num_workers=3, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, sampler=train_gt_sampler,
                                         num_workers=3, pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    # NOTE(review): size is passed as (w, h) although input_size is (h, w);
    # this only matters for non-square inputs -- confirm the intended order.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # file-name stem shared by every snapshot written below
    snapshot_stem = 'VOC_' + osp.basename(osp.abspath(__file__)).split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # Set only when the semi step below runs, so the D step can tell
            # whether there is a remainder batch to append.
            pred_remain = None
            ignore_mask_remain = None

            # do semi first
            if (trainloader_remain is not None
                    and (args.lambda_semi > 0 or args.lambda_semi_adv > 0)
                    and i_iter >= args.semi_start_adv):
                batch, trainloader_remain_iter = _next_batch(
                    trainloader_remain, trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda()

                pred = interp(model(images))
                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # nothing is ignored for the remainder batch
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape, dtype=bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv/args.iter_size

                # report the unweighted value; skip when the weight is zero
                # (the loss is then zero and the division would fail)
                if args.lambda_semi_adv != 0:
                    loss_semi_adv_value += _loss_to_float(loss_semi_adv)/args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask: keep only pixels whose highest
                    # softmax probability reaches the confidence threshold
                    # (the discriminator-confidence threshold args.mask_T
                    # is not used by this variant)
                    probs = F.softmax(pred, dim=1).data.cpu().numpy()
                    confidence = probs.max(axis=1).astype(np.float64)

                    semi_ignore_mask = (confidence < 0.999999)
                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum())/semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        # NOTE(review): loss_semi_adv is not back-propagated
                        # on this path -- confirm that is intended.
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt)
                        loss_semi = loss_semi/args.iter_size
                        loss_semi_value += _loss_to_float(loss_semi)/args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            # train with source
            batch, trainloader_iter = _next_batch(trainloader, trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # the generator is rewarded when D scores its prediction as "gt"
            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss/args.iter_size
            loss.backward()
            loss_seg_value += _loss_to_float(loss_seg)/args.iter_size
            loss_adv_pred_value += _loss_to_float(loss_adv_pred)/args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred (detached so no gradient reaches the generator)
            pred = pred.detach()

            # append the remainder predictions only if the semi step produced them
            if args.D_remain and pred_remain is not None:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D/args.iter_size/2
            loss_D.backward()
            loss_D_value += _loss_to_float(loss_D)

            # train with gt
            # get gt labels
            batch, trainloader_gt_iter = _next_batch(trainloader_gt, trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D/args.iter_size/2
            loss_D.backward()
            loss_D_value += _loss_to_float(loss_D)

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps-1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, snapshot_stem+'_'+str(args.num_steps)+'.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, snapshot_stem+'_'+str(args.num_steps)+'_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, snapshot_stem+'_'+str(i_iter)+'.pth'))
            torch.save(model_D.state_dict(),
                       osp.join(args.snapshot_dir, snapshot_stem+'_'+str(i_iter)+'_D.pth'))

    # `start` is not defined in this function; presumably a module-level
    # timestamp taken at launch -- verify against the rest of the file.
    end = timeit.default_timer()
    print(end-start, 'seconds')
def main():
    tag = 0

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_DS = []
    for i in range(20):
        model_DS.append(Discriminator2_mul(num_classes=args.num_classes))
    # if args.restore_from_D is not None:
    #      model_D.load_state_dict(torch.load(args.restore_from_D))
    for model_D in model_DS:
        model_D = nn.DataParallel(model_D)
        model_D.train()
        model_D.cuda()

    model_D2 = Discriminator2(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D2.load_state_dict(torch.load(args.restore_from_D))
    model_D2 = nn.DataParallel(model_D2)
    model_D2.train()
    model_D2.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data == 0:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        #sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_DS = []
    for model_D in model_DS:
        optimizer_DS.append(
            optim.Adam(model_D.parameters(),
                       lr=args.learning_rate_D,
                       betas=(0.9, 0.99)))
    for optimizer_D in optimizer_DS:
        optimizer_D.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = torch.nn.BCELoss()
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for optimizer_D in optimizer_DS:
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for model_D in model_DS:
                for param in model_D.parameters():
                    param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            try:
                _, batch = trainloader_iter.next()
            except:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.next()

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))
            loss_seg = loss_calc(pred, labels)

            pred_re0 = F.softmax(pred, dim=1)
            pred_re = pred_re0.repeat(1, 3, 1, 1)
            indices_1 = torch.index_select(
                images, 1,
                Variable(torch.LongTensor([0])).cuda())
            indices_2 = torch.index_select(
                images, 1,
                Variable(torch.LongTensor([1])).cuda())
            indices_3 = torch.index_select(
                images, 1,
                Variable(torch.LongTensor([2])).cuda())
            img_re = torch.cat([
                indices_1.repeat(1, 21, 1, 1),
                indices_2.repeat(1, 21, 1, 1),
                indices_3.repeat(1, 21, 1, 1),
            ], 1)

            mul_img = pred_re * img_re  #10,63,321,321

            D_out_2 = model_D2(mul_img)

            loss_adv_pred_ = 0

            for i_l in range(labels.shape[0]):
                label_set = np.unique(labels[i_l]).tolist()
                for ls in label_set:
                    if ls != 0 and ls != 255:
                        ls = int(ls)

                        img_p = torch.cat([
                            mul_img[i_l][ls].unsqueeze(0).unsqueeze(0),
                            mul_img[i_l][ls + 21].unsqueeze(0).unsqueeze(0),
                            mul_img[i_l][ls + 21 +
                                         21].unsqueeze(0).unsqueeze(0)
                        ], 1)
                        #print 1,img_p.size()#(1L, 3L, 321L, 321L)
                        imgs = img_p
                        imgs1 = imgs[:, :, 0:107, 0:107]
                        imgs2 = imgs[:, :, 0:107, 107:214]
                        imgs3 = imgs[:, :, 0:107, 214:321]
                        imgs4 = imgs[:, :, 107:214, 0:107]
                        imgs5 = imgs[:, :, 107:214, 107:214]
                        imgs6 = imgs[:, :, 107:214, 214:321]
                        imgs7 = imgs[:, :, 214:321, 0:107]
                        imgs8 = imgs[:, :, 214:321, 107:214]
                        imgs9 = imgs[:, :, 214:321, 214:321]

                        #print 2, imgs1.size()#(1L, 3L, 107L, 107L)

                        img_ps = torch.cat([
                            imgs1, imgs2, imgs3, imgs4, imgs5, imgs6, imgs7,
                            imgs8, imgs9
                        ], 0)
                        #print 3, img_ps.size()#(9L, 3L, 107L, 107L)

                        D_out = model_DS[ls - 1](img_ps)
                        loss_adv_pred_ = loss_adv_pred_ + bce_loss(
                            D_out, make_D_label(gt_label, D_out))

            loss_adv_pred = loss_adv_pred_ * 0.5 + bce_loss(
                D_out_2, make_D_label(gt_label, D_out_2))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            )[0] / args.iter_size

            # train D

            # bring back requires_grad
            for model_D in model_DS:
                for param in model_D.parameters():
                    param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            pred_re0 = F.softmax(pred, dim=1)
            pred_re2 = pred_re0.repeat(1, 3, 1, 1)

            mul_img2 = pred_re2 * img_re

            D_out_2 = model_D2(mul_img2)

            loss_adv_pred_ = 0

            for i_l in range(labels.shape[0]):
                label_set = np.unique(labels[i_l]).tolist()
                for ls in label_set:
                    if ls != 0 and ls != 255:
                        ls = int(ls)

                        img_p = torch.cat([
                            mul_img2[i_l][ls].unsqueeze(0).unsqueeze(0),
                            mul_img2[i_l][ls + 21].unsqueeze(0).unsqueeze(0),
                            mul_img2[i_l][ls + 21 +
                                          21].unsqueeze(0).unsqueeze(0)
                        ], 1)
                        # print 1,img_p.size()#(1L, 3L, 321L, 321L)
                        imgs = img_p
                        imgs1 = imgs[:, :, 0:107, 0:107]
                        imgs2 = imgs[:, :, 0:107, 107:214]
                        imgs3 = imgs[:, :, 0:107, 214:321]
                        imgs4 = imgs[:, :, 107:214, 0:107]
                        imgs5 = imgs[:, :, 107:214, 107:214]
                        imgs6 = imgs[:, :, 107:214, 214:321]
                        imgs7 = imgs[:, :, 214:321, 0:107]
                        imgs8 = imgs[:, :, 214:321, 107:214]
                        imgs9 = imgs[:, :, 214:321, 214:321]

                        # print 2, imgs1.size()#(1L, 3L, 107L, 107L)

                        img_ps = torch.cat([
                            imgs1, imgs2, imgs3, imgs4, imgs5, imgs6, imgs7,
                            imgs8, imgs9
                        ], 0)
                        # print 3, img_ps.size()#(9L, 3L, 107L, 107L)

                        D_out = model_DS[ls - 1](img_ps)
                        loss_adv_pred_ = loss_adv_pred_ + bce_loss(
                            D_out, make_D_label(pred_label, D_out))

            loss_D = loss_adv_pred_ * 0.5 + bce_loss(
                D_out_2, make_D_label(pred_label, D_out_2))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

            # train with gt
            # get gt labels
            try:
                _, batch = trainloader_gt_iter.next()
            except:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = trainloader_gt_iter.next()

            img_gt, labels_gt, _, name = batch
            img_gt = Variable(img_gt).cuda()
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            lb = D_gt_v.detach()

            pred_re3 = D_gt_v.repeat(1, 3, 1, 1)
            indices_1 = torch.index_select(
                img_gt, 1,
                Variable(torch.LongTensor([0])).cuda())
            indices_2 = torch.index_select(
                img_gt, 1,
                Variable(torch.LongTensor([1])).cuda())
            indices_3 = torch.index_select(
                img_gt, 1,
                Variable(torch.LongTensor([2])).cuda())
            img_re3 = torch.cat([
                indices_1.repeat(1, 21, 1, 1),
                indices_2.repeat(1, 21, 1, 1),
                indices_3.repeat(1, 21, 1, 1),
            ], 1)

            #mul_img3 = img_re3
            mul_img3 = pred_re3 * img_re3

            D_out_2 = model_D2(mul_img3)

            loss_adv_pred_ = 0

            for i_l in range(labels_gt.shape[0]):
                label_set = np.unique(labels_gt[i_l]).tolist()

                for ls in label_set:
                    if ls != 0 and ls != 255:
                        ls = int(ls)

                        img_p = torch.cat([
                            mul_img3[i_l][ls].unsqueeze(0).unsqueeze(0),
                            mul_img3[i_l][ls + 21].unsqueeze(0).unsqueeze(0),
                            mul_img3[i_l][ls + 21 +
                                          21].unsqueeze(0).unsqueeze(0)
                        ], 1)
                        # print 1,img_p.size()#(1L, 3L, 321L, 321L)
                        imgs = img_p
                        imgs1 = imgs[:, :, 0:107, 0:107]
                        imgs2 = imgs[:, :, 0:107, 107:214]
                        imgs3 = imgs[:, :, 0:107, 214:321]
                        imgs4 = imgs[:, :, 107:214, 0:107]
                        imgs5 = imgs[:, :, 107:214, 107:214]
                        imgs6 = imgs[:, :, 107:214, 214:321]
                        imgs7 = imgs[:, :, 214:321, 0:107]
                        imgs8 = imgs[:, :, 214:321, 107:214]
                        imgs9 = imgs[:, :, 214:321, 214:321]

                        # print 2, imgs1.size()#(1L, 3L, 107L, 107L)

                        img_ps = torch.cat([
                            imgs1, imgs2, imgs3, imgs4, imgs5, imgs6, imgs7,
                            imgs8, imgs9
                        ], 0)
                        # print 3, img_ps.size()#(9L, 3L, 107L, 107L)

                        D_out = model_DS[ls - 1](img_ps)
                        loss_adv_pred_ = loss_adv_pred_ + bce_loss(
                            D_out, make_D_label(gt_label, D_out))
                        '''


                        if tag == 0:
                            # print lb[0].size()
                            # lb1=lb[0][0]
                            # lb2 = lb[0][1]
                            # lb3 = lb[0][2]
                            # lb4 = lb[0][3]
                            # lb5 = lb[0][4]
                            # lb6 = lb[0][5]
                            # lb7 = lb[0][0]
                            # lb8 = lb[0][0]
                            # lb9 = lb[0][0]
                            # lb10 = lb[0][0]


                            print label_set, name[0]
                            print ls
                            imgs = imgs.squeeze()
                            imgs = imgs.transpose(0, 1)
                            imgs = imgs.transpose(1, 2)

                            imgs1 = imgs1.squeeze()
                            imgs1 = imgs1.transpose(0, 1)
                            imgs1 = imgs1.transpose(1, 2)

                            imgs2 = imgs2.squeeze()
                            imgs2 = imgs2.transpose(0, 1)
                            imgs2 = imgs2.transpose(1, 2)

                            imgs3 = imgs3.squeeze()
                            imgs3 = imgs3.transpose(0, 1)
                            imgs3 = imgs3.transpose(1, 2)

                            imgs4 = imgs4.squeeze()
                            imgs4 = imgs4.transpose(0, 1)
                            imgs4 = imgs4.transpose(1, 2)

                            imgs5 = imgs5.squeeze()
                            imgs5 = imgs5.transpose(0, 1)
                            imgs5 = imgs5.transpose(1, 2)

                            imgs6 = imgs6.squeeze()
                            imgs6 = imgs6.transpose(0, 1)
                            imgs6 = imgs6.transpose(1, 2)

                            imgs7 = imgs7.squeeze()
                            imgs7 = imgs7.transpose(0, 1)
                            imgs7 = imgs7.transpose(1, 2)

                            imgs8 = imgs8.squeeze()
                            imgs8 = imgs8.transpose(0, 1)
                            imgs8 = imgs8.transpose(1, 2)

                            imgs9 = imgs9.squeeze()
                            imgs9 = imgs9.transpose(0, 1)
                            imgs9 = imgs9.transpose(1, 2)

                            imgs = imgs.data.cpu().numpy()
                            imgs1 = imgs1.data.cpu().numpy()
                            imgs2 = imgs2.data.cpu().numpy()
                            imgs3 = imgs3.data.cpu().numpy()
                            imgs4 = imgs4.data.cpu().numpy()
                            imgs5 = imgs5.data.cpu().numpy()
                            imgs6 = imgs6.data.cpu().numpy()
                            imgs7 = imgs7.data.cpu().numpy()
                            imgs8 = imgs8.data.cpu().numpy()
                            imgs9 = imgs9.data.cpu().numpy()
                            cv2.imwrite('/data1/wyc/1.png', imgs1)
                            cv2.imwrite('/data1/wyc/2.png', imgs2)
                            cv2.imwrite('/data1/wyc/3.png', imgs3)
                            cv2.imwrite('/data1/wyc/4.png', imgs4)
                            cv2.imwrite('/data1/wyc/5.png', imgs5)
                            cv2.imwrite('/data1/wyc/6.png', imgs6)
                            cv2.imwrite('/data1/wyc/7.png', imgs7)
                            cv2.imwrite('/data1/wyc/8.png', imgs8)
                            cv2.imwrite('/data1/wyc/9.png', imgs9)
                            cv2.imwrite('/data1/wyc/img.png', imgs)
                            tag = 1
                        '''

            loss_D = loss_adv_pred_ * 0.5 + bce_loss(
                D_out_2, make_D_label(gt_label, D_out_2))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()

        for optimizer_D in optimizer_DS:
            optimizer_D.step()

        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' +
                    os.path.abspath(__file__).split('/')[-1].split('.')[0] +
                    '_' + str(args.num_steps) + '.pth'))
            #torch.save(model_D.state_dict(),osp.join(args.snapshot_dir, 'VOC_'+os.path.abspath(__file__).split('/')[-1].split('.')[0]+'_'+str(args.num_steps)+'_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' +
                    os.path.abspath(__file__).split('/')[-1].split('.')[0] +
                    '_' + str(i_iter) + '.pth'))
            #torch.save(model_D.state_dict(),osp.join(args.snapshot_dir, 'VOC_'+os.path.abspath(__file__).split('/')[-1].split('.')[0]+'_'+str(i_iter)+'_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
# Example no. 4
# 0
def main():
    """Train the segmentation network together with a feature-matching discriminator.

    All settings are read from the module-level ``args`` object. The function

    1. builds ``Res_Deeplab`` and copies every pretrained parameter whose name
       and shape match,
    2. builds ``FCDiscriminator`` (optionally restored from
       ``args.restore_from_D``),
    3. creates the VOC data loaders (full data, or a labelled subset when
       ``args.partial_data`` is set),
    4. runs ``args.num_steps`` optimisation steps, saving snapshots into
       ``args.snapshot_dir`` every ``args.save_pred_every`` steps and at the
       end.

    Relies on names defined elsewhere in the file (``args``, ``criterion``,
    ``loss_calc``, ``one_hot``, ``adjust_learning_rate``,
    ``adjust_learning_rate_D``, ``IMG_MEAN``, ``start`` ...).
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    def make_loader(dataset, sampler=None):
        # DataLoader rejects shuffle=True together with an explicit sampler,
        # so shuffle only when no sampler is given.
        return data.DataLoader(dataset,
                               batch_size=args.batch_size,
                               shuffle=sampler is None,
                               sampler=sampler,
                               num_workers=16,
                               pin_memory=True)

    if args.partial_data is None:
        trainloader = make_loader(train_dataset)
        trainloader_gt = make_loader(train_gt_dataset)
        # The labelled loader already covers the whole dataset, so it doubles
        # as the "all images" loader used by the feature-matching loss.
        # (Previously this name was undefined on this branch -> NameError.)
        trainloader_all = trainloader
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle files are binary: open with 'rb' and close the handle.
            # NOTE: unpickling executes arbitrary code; only load trusted files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        labelled_ids = train_ids[:partial_size]

        trainloader_all = make_loader(
            train_dataset, data.sampler.SubsetRandomSampler(train_ids))
        trainloader = make_loader(
            train_dataset, data.sampler.SubsetRandomSampler(labelled_ids))
        trainloader_gt = make_loader(
            train_gt_dataset, data.sampler.SubsetRandomSampler(labelled_ids))

    trainloader_all_iter = iter(trainloader_all)
    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # bilinear upsampling of the network output to the crop size
    # NOTE(review): size is passed as (w, h) although input_size is (h, w);
    # this is identical for square crops - confirm before using non-square
    # inputs.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    for i_iter in range(args.num_steps):

        # per-step running values used only for logging
        loss_seg_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source (labelled) data; restart the loader when an
            # epoch is exhausted
            try:
                batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_seg.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size

            if i_iter >= args.adv_start:

                # fm loss calc: compare discriminator features of predictions
                # (any image) with those of ground-truth label maps
                try:
                    batch = next(trainloader_all_iter)
                except StopIteration:
                    # Restart the *all-data* iterator (the original rebound
                    # trainloader_iter here, so the retry failed again).
                    trainloader_all_iter = iter(trainloader_all)
                    batch = next(trainloader_all_iter)

                images, labels, _, _ = batch
                images = Variable(images).cuda(args.gpu)
                pred = interp(model(images))

                _, D_out_y_pred = model_D(F.softmax(pred))

                # Reuse the persistent GT iterator instead of rebuilding it
                # (and respawning its workers) on every sub-iteration.
                try:
                    batch = next(trainloader_gt_iter)
                except StopIteration:
                    trainloader_gt_iter = iter(trainloader_gt)
                    batch = next(trainloader_gt_iter)

                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

                _, D_out_y_gt = model_D(D_gt_v)

                fm_loss = torch.mean(
                    torch.abs(
                        torch.mean(D_out_y_gt, 0) -
                        torch.mean(D_out_y_pred, 0)))

                # `loss` is only used for logging below.
                loss = loss_seg + args.lambda_fm * fm_loss

                # NOTE(review): the gradient comes from the unweighted
                # fm_loss, not from args.lambda_fm * fm_loss - confirm this
                # is intended.
                fm_loss.backward()
                loss_fm_value += fm_loss.data.cpu().numpy()[0] / args.iter_size
                loss_value += loss.data.cpu().numpy()[0] / args.iter_size

                # train D

                # bring back requires_grad
                for param in model_D.parameters():
                    param.requires_grad = True

                # train with pred (detached so no gradient reaches G)
                pred = pred.detach()

                D_out_z, _ = model_D(F.softmax(pred))
                # targets live on the same device as the models
                y_fake_ = Variable(torch.zeros(D_out_z.size(0),
                                               1)).cuda(args.gpu)
                loss_D_fake = criterion(D_out_z, y_fake_)

                # train with gt: same ground-truth batch as used for fm_loss
                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)

                D_out_z_gt, _ = model_D(D_gt_v)

                y_real_ = Variable(torch.ones(D_out_z_gt.size(0),
                                              1)).cuda(args.gpu)

                loss_D_real = criterion(D_out_z_gt, y_real_)
                loss_D = loss_D_fake + loss_D_real
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_D = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value, loss_D_value))
        print('fm_loss: ', loss_fm_value, ' g_loss: ', loss_value)

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    # `start` is not set in this function; presumably a module-level timestamp.
    end = timeit.default_timer()
    print(end - start, 'seconds')
# Example no. 5
# 0
def main():
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network

    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters (weights)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(
            args.restore_from
        )  ## http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth
    else:
        saved_state_dict = torch.load(args.restore_from)
        #checkpoint = torch.load(args.restore_from)_

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()  # state_dict() is current model
    for name, param in new_params.items():
        #print (name) # 'conv1.weight, name:param(value), dict
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            #print('copy {}'.format(name))
    model.load_state_dict(new_params)
    #model.load_state_dict(checkpoint['state_dict'])
    #optimizer.load_state_dict(args.checkpoint['optim_dict'])

    model.train(
    )  # https://pytorch.org/docs/stable/nn.html, Sets the module in training mode.
    model.cuda(args.gpu)  ##

    cudnn.benchmark = True  # This flag allows you to enable the inbuilt cudnn auto-tuner to find the best algorithm to use for your hardware

    # init D

    model_D = FCDiscriminator(num_classes=args.num_classes)
    #args.restore_from_D = 'snapshots/linear2/VOC_25000_D.pth'
    if args.restore_from_D is not None:  # None
        model_D.load_state_dict(torch.load(args.restore_from_D))
        # checkpoint_D = torch.load(args.restore_from_D)
        # model_D.load_state_dict(checkpoint_D['state_dict'])
        # optimizer_D.load_state_dict(checkpoint_D['optim_dict'])
    model_D.train()
    model_D.cuda(args.gpu)

    if USECALI:
        model_cali = ModelWithTemperature(model, model_D)
        model_cali.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_remain = VOCDataSet(args.data_dir,
                                      args.data_list_remain,
                                      crop_size=input_size,
                                      scale=args.random_scale,
                                      mirror=args.random_mirror,
                                      mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    train_dataset_size_remain = len(train_dataset_remain)

    print train_dataset_size
    print train_dataset_size_remain

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)
    if args.partial_data is None:  #if not partial, load all

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        #sample partial data
        #args.partial_data = 0.125
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            train_ids = pickle.load(open(args.partial_id))
            print('loading train ids from {}'.format(args.partial_id))
        else:  #args.partial_id is none
            train_ids = range(train_dataset_size)
            train_ids_remain = range(train_dataset_size_remain)
            np.random.shuffle(train_ids)  #shuffle!
            np.random.shuffle(train_ids_remain)

        pickle.dump(train_ids,
                    open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                         'wb'))  #randomly suffled ids

        #sampler
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:])  # 0~1/8,
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids_remain[:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:])
        # train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size]) # 0~1/8
        # train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:]) # used as unlabeled, 7/8
        # train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        #train loader
        trainloader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            num_workers=3,
            pin_memory=True)  # multi-process data loading
        trainloader_remain = data.DataLoader(train_dataset_remain,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        # trainloader_remain = data.DataLoader(train_dataset,
        #                                      batch_size=args.batch_size, sampler=train_remain_sampler, num_workers=3,
        #                                     pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network

    # model.optim_paramters(args) = list(dict1, dict2), dict1 >> 'lr' and 'params'
    # print(type(model.optim_parameters(args)[0]['params'])) # generator
    #print(model.state_dict()['coeff'][0]) #confirmed

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    #optimizer.add_param_group({"params":model.coeff}) # assign new coefficient to the optimizer
    #print(len(optimizer.param_groups))
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()  #initialize

    if USECALI:
        optimizer_cali = optim.LBFGS([model_cali.temperature],
                                     lr=0.01,
                                     max_iter=50)
        optimizer_cali.zero_grad()

        nll_criterion = BCEWithLogitsLoss().cuda()  # BCE!!
        ece_criterion = ECELoss().cuda()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(
        size=(input_size[1], input_size[0]), mode='bilinear'
    )  # okay it automatically change to functional.interpolate
    # 321, 321

    if version.parse(torch.__version__) >= version.parse('0.4.0'):  #0.4.1
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1
    semi_ratio_sum = 0
    semi_sum = 0
    loss_seg_sum = 0
    loss_adv_sum = 0
    loss_vat_sum = 0
    l_seg_sum = 0
    l_vat_sum = 0
    l_adv_sum = 0
    logits_list = []
    labels_list = []

    #https: // towardsdatascience.com / understanding - pytorch -with-an - example - a - step - by - step - tutorial - 81fc5f8c4e8e

    for i_iter in range(args.num_steps):

        loss_seg_value = 0  # L_seg
        loss_adv_pred_value = 0  # 0.01 L_adv
        loss_D_value = 0  # L_D
        loss_semi_value = 0  # 0.1 L_semi
        loss_semi_adv_value = 0  # 0.001 L_adv
        loss_vat_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)  #changing lr by iteration

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            ###################### train G!!!###########################
            ############################################################
            # don't accumulate grads in D

            for param in model_D.parameters(
            ):  # <class 'torch.nn.parameter.Parameter'>, convolution weights
                param.requires_grad = False  # do not update gradient of D (freeze) while G

            ######### do unlabeled first!! 0.001 L_adv + 0.1 L_semi ###############

            # lambda_semi, lambda_adv for unlabeled
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = trainloader_remain_iter.next(
                    )  #remain = unlabeled
                    print(trainloader_remain_iter.next())
                except:
                    trainloader_remain_iter = enumerate(
                        trainloader_remain)  # impose counters
                    _, batch = trainloader_remain_iter.next()

                # only access to img
                images, _, _, _ = batch  # <class 'torch.Tensor'>
                images = Variable(images).cuda(
                    args.gpu)  # <class 'torch.Tensor'>

                pred = interp(
                    model(images))  # S(X), pred <class 'torch.Tensor'>
                pred_remain = pred.detach(
                )  #use detach() when attempting to remove a tensor from a computation graph, will be used for D
                # https://discuss.pytorch.org/t/clone-and-detach-in-v0-4-0/16861
                # The difference is that detach refers to only a given variable on which it's called.
                # torch.no_grad affects all operations taking place within the with statement. >> for context,
                # requires_grad is for tensor

                # pred >> (8,21,321,321), L_adv
                D_out = interp(
                    model_D(F.softmax(pred))
                )  # D(S(X)), confidence, 8,1,321,321, not detached, there was not dim
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)  # (8,321,321) 0~1

                # 0.001 L_adv!!!!
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    np.bool)  # no ignore_mask for unlabeled adv
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label,
                                        ignore_mask_remain))  #gt_label =1,
                # -log(D(S(X)))
                loss_semi_adv = loss_semi_adv / args.iter_size  #normalization

                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv

                ##--- visualization, pred(8,21,321,321), D_out_sigmoid(8,321,321)
                """
                if i_iter % 1000 == 0:
                    vpred = pred.transpose(1, 2).transpose(2, 3).contiguous()  # (8,321,321,21)
                    vpred = vpred.view(-1, 21)  # (8*321*321, 21)
                    vlogsx = F.log_softmax(vpred)  # torch.Tensor
                    vsemi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    vsemi_gt = Variable(torch.FloatTensor(vsemi_gt).long()).cuda(gpu)
                    vlogsx = vlogsx.gather(1, vsemi_gt.view(-1, 1))
                    sx = F.softmax(vpred).gather(1, vsemi_gt.view(-1, 1))
                    vD_out_sigmoid = Variable(torch.FloatTensor(D_out_sigmoid)).cuda(gpu).view(-1, 1)
                    vlogsx = (vlogsx*(2.5*vD_out_sigmoid+0.5))
                    vlogsx = -vlogsx.squeeze(dim=1)
                    sx = sx.squeeze(dim=1)
                    vD_out_sigmoid = vD_out_sigmoid.squeeze(dim=1)
                    dsx = vD_out_sigmoid.data.cpu().detach().numpy()
                    vlogsx = vlogsx.data.cpu().detach().numpy()
                    sx = sx.data.cpu().detach().numpy()
                    plt.clf()
                    plt.figure(figsize=(15, 5))
                    plt.subplot(131)
                    plt.ylim(0, 0.004)
                    plt.scatter(dsx, vlogsx, s = 0.1)  # variable requires grad cannot call numpy >> detach
                    plt.xlabel('D(S(X))')
                    plt.ylabel('Loss_Semi per Pixel')
                    plt.subplot(132)
                    plt.scatter(dsx, vlogsx, s = 0.1)  # variable requires grad cannot call numpy >> detach
                    plt.xlabel('D(S(X))')
                    plt.ylabel('Loss_Semi per Pixel')
                    plt.subplot(133)
                    plt.scatter(dsx, sx, s=0.1)
                    plt.xlabel('D(S(X))')
                    plt.ylabel('S(x)')
                    plt.savefig('/home/eungyo/AdvSemiSeg/plot/'  + str(i_iter) + '.png')
                    """

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:

                    semi_gt = pred.data.cpu().numpy().argmax(
                        axis=1
                    )  # pred=S(X) ((8,21,321,321)), semi_gt is not one-hot, 8,321,321
                    #(8, 321, 321)

                    if not USECALI:
                        semi_ignore_mask = (
                            D_out_sigmoid < args.mask_T
                        )  # both (8,321,321) 0~1threshold!, numpy
                        semi_gt[
                            semi_ignore_mask] = 255  # Yhat, ignore pixel becomes 255
                        semi_ratio = 1.0 - float(semi_ignore_mask.sum(
                        )) / semi_ignore_mask.size  # ignored pixels / H*W
                        print('semi ratio: {:.4f}'.format(semi_ratio))

                        if semi_ratio == 0.0:
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)
                            confidence = torch.FloatTensor(
                                D_out_sigmoid)  ## added, only pred is on cuda
                            loss_semi = args.lambda_semi * weighted_loss_calc(
                                pred, semi_gt, args.gpu, confidence)

                    else:
                        semi_ratio = 1
                        semi_gt = (torch.FloatTensor(semi_gt))  # (8,321,321)
                        confidence = torch.FloatTensor(
                            F.sigmoid(
                                model_cali.temperature_scale(D_out.view(
                                    -1))).data.cpu().numpy())  # (8*321*321,)
                        loss_semi = args.lambda_semi * calibrated_loss_calc(
                            pred, semi_gt, args.gpu, confidence, accuracies,
                            n_bin
                        )  #  L_semi = Yhat * log(S(X)) # loss_calc(pred, semi_gt, args.gpu)
                        # pred(8,21,321,321)

                    if semi_ratio != 0:
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy(
                        ) / args.lambda_semi

                        if args.method == 'vatent' or args.method == 'vat':
                            #v_loss = vat_loss(model, images, pred, eps=args.epsilon[i])  # R_vadv
                            weighted_v_loss = weighted_vat_loss(
                                model,
                                images,
                                pred,
                                confidence,
                                eps=args.epsilon)

                            if args.method == 'vatent':
                                #v_loss += entropy_loss(pred)  # R_cent (conditional entropy loss)
                                weighted_v_loss += weighted_entropy_loss(
                                    pred, confidence)

                            v_loss = weighted_v_loss / args.iter_size
                            loss_vat_value += v_loss.data.cpu().numpy()
                            loss_semi_adv += args.alpha * v_loss

                            loss_vat_sum += loss_vat_value
                            if i_iter % 100 == 0 and sub_i == 4:
                                l_vat_sum = loss_vat_sum / 100
                                if i_iter == 0:
                                    l_vat_sum = l_vat_sum * 100
                                loss_vat_sum = 0

                        loss_semi += loss_semi_adv
                        loss_semi.backward(
                        )  # 0.001 L_adv + 0.1 L_semi, backward == back propagation

            else:
                loss_semi = None
                loss_semi_adv = None

            ###########train with source (labeled data)############### L_ce + 0.01 * L_adv

            try:
                _, batch = trainloader_iter.next()
            except:
                trainloader_iter = enumerate(trainloader)  # safe coding
                _, batch = trainloader_iter.next()  #counter, batch

            images, labels, _, _ = batch  # also get labels images(8,321,321)
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (
                labels.numpy() == 255
            )  # ignored pixels == 255 >> 1, yes ignored mask for labeled data

            pred = interp(model(images))  # S(X), 8,21,321,321
            loss_seg = loss_calc(pred, labels,
                                 args.gpu)  # -Y*logS(X)= L_ce, not detached

            if USED:
                softsx = F.softmax(pred, dim=1)
                D_out = interp(model_D(softsx))  # D(S(X)), L_adv

                loss_adv_pred = bce_loss(
                    D_out, make_D_label(
                        gt_label,
                        ignore_mask))  # both  8,1,321,321, gt_label = 1
                # L_adv =  -log(D(S(X)), make_D_label is all 1 except ignored_region

                loss = loss_seg + args.lambda_adv_pred * loss_adv_pred
                if USECALI:
                    if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                        ) and i_iter >= args.semi_start_adv:
                        with torch.no_grad():
                            _, prediction = torch.max(softsx, 1)
                            labels_mask = (
                                (labels > 0) *
                                (labels != 255)) | (prediction.data.cpu() > 0)
                            labels = labels[labels_mask]
                            prediction = prediction[labels_mask]
                            fake_mask = (labels.data.cpu().numpy() !=
                                         prediction.data.cpu().numpy())
                            real_label = make_conf_label(
                                1, fake_mask
                            )  # (10*321*321, ) 0 or 1 (fake or real)

                            logits = D_out.squeeze(dim=1)
                            logits = logits[labels_mask]
                            logits_list.append(logits)  # initialize
                            labels_list.append(real_label)

                        if (i_iter * args.iter_size * args.batch_size + sub_i +
                                1) % train_dataset_size == 0:
                            logits = torch.cat(logits_list).cuda(
                            )  # overall 5000 images in val,  #logits >> 5000,100, (1464*321*321,)
                            labels = torch.cat(labels_list).cuda()
                            before_temperature_nll = nll_criterion(
                                logits, labels).item()  ####modify
                            before_temperature_ece, _, _ = ece_criterion(
                                logits, labels)  # (1464*321*321,)
                            before_temperature_ece = before_temperature_ece.item(
                            )
                            print('Before temperature - NLL: %.3f, ECE: %.3f' %
                                  (before_temperature_nll,
                                   before_temperature_ece))

                            def eval():
                                loss_cali = nll_criterion(
                                    model_cali.temperature_scale(logits),
                                    labels)
                                loss_cali.backward()
                                return loss_cali

                            optimizer_cali.step(
                                eval)  # just one backward >> not 50 iterations
                            after_temperature_nll = nll_criterion(
                                model_cali.temperature_scale(logits),
                                labels).item()
                            after_temperature_ece, accuracies, n_bin = ece_criterion(
                                model_cali.temperature_scale(logits), labels)
                            after_temperature_ece = after_temperature_ece.item(
                            )
                            print('Optimal temperature: %.3f' %
                                  model_cali.temperature.item())
                            print(
                                'After temperature - NLL: %.3f, ECE: %.3f' %
                                (after_temperature_nll, after_temperature_ece))

                            logits_list = []
                            labels_list = []

            else:
                loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_sum += loss_seg / args.iter_size
            if USED:
                loss_adv_sum += loss_adv_pred
            if i_iter % 100 == 0 and sub_i == 4:
                l_seg_sum = loss_seg_sum / 100
                if USED:
                    l_adv_sum = loss_adv_sum / 100
                if i_iter == 0:
                    l_seg_sum = l_seg_sum * 100
                    l_adv_sum = l_adv_sum * 100
                loss_seg_sum = 0
                loss_adv_sum = 0

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            if USED:
                loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
                ) / args.iter_size

            ##################### train D!!!###########################
            ###########################################################
            # bring back requires_grad
            if USED:
                for param in model_D.parameters():
                    param.requires_grad = True  # before False.

                ############# train with pred S(X)############# labeled + unlabeled
                pred = pred.detach(
                )  #orginally only use labeled data, freeze S(X) when train D,

                # We do train D with the unlabeled data. But the difference is quite small

                if args.D_remain:  #default true
                    pred = torch.cat(
                        (pred, pred_remain), 0
                    )  # pred_remain(unlabeled S(x)) is detached  16,21,321,321
                    ignore_mask = np.concatenate(
                        (ignore_mask, ignore_mask_remain),
                        axis=0)  # 16,321,321

                D_out = interp(
                    model_D(F.softmax(pred, dim=1))
                )  # D(S(X)) 16,1,321,321  # softmax(pred,dim=1) for 0.4, not nessesary

                loss_D = bce_loss(D_out,
                                  make_D_label(pred_label,
                                               ignore_mask))  # pred_label = 0

                # -log(1-D(S(X)))
                loss_D = loss_D / args.iter_size / 2  # iter_size = 1, /2 because there is G and D
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

                ################## train with gt################### only labeled
                #VOCGT and VOCdataset can be reduced to one dataset in this repo.
                # get gt labels Y
                #print "before train gt"
                try:
                    print(trainloader_gt_iter.next())  # len 732
                    _, batch = trainloader_gt_iter.next()

                except:
                    trainloader_gt_iter = enumerate(trainloader_gt)
                    _, batch = trainloader_gt_iter.next()
                #print "train with gt?"
                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)  #one_hot
                ignore_mask_gt = (labels_gt.numpy() == 255
                                  )  # same as ignore_mask (8,321,321)
                #print "finish"
                D_out = interp(model_D(D_gt_v))  # D(Y)
                loss_D = bce_loss(D_out,
                                  make_D_label(gt_label,
                                               ignore_mask_gt))  # log(D(Y))
                loss_D = loss_D / args.iter_size / 2

                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()

        if USED:
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))  #snapshot
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.6f}, loss_semi = {5:.6f}, loss_semi_adv = {6:.3f}, loss_vat = {7: .5f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value, loss_vat_value))
        #                                        L_ce             L_adv for labeled      L_D                L_semi              L_adv for unlabeled
        #loss_adv should be inversely proportional to the loss_D if they are seeing the same data.
        # loss_adv_p is essentially the inverse loss of loss_D. We expect them to achieve a good balance during the adversarial training
        # loss_D is around 0.2-0.5   >> good
        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            #torch.save(state, osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth.tar'))
            #torch.save(state_D, osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth.tar'))
            break

        if i_iter % 100 == 0 and sub_i == 4:  #loss_seg_value
            wdata = "iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.6f}, loss_semi = {5:.8f}, loss_semi_adv = {6:.3f}, l_vat_sum = {7: .5f}, loss_label = {8: .4}\n".format(
                i_iter, args.num_steps, l_seg_sum, l_adv_sum, loss_D_value,
                loss_semi_value, loss_semi_adv_value, l_vat_sum,
                l_seg_sum + 0.01 * l_adv_sum)
            #wdata2 = "{0:8d} {1:s} {2:s} {3:s} {4:s} {5:s} {6:s} {7:s} {8:s}\n".format(i_iter,str(model.coeff[0])[8:14],str(model.coeff[1])[8:14],str(model.coeff[2])[8:14],str(model.coeff[3])[8:14],str(model.coeff[4])[8:14],str(model.coeff[5])[8:14],str(model.coeff[6])[8:14],str(model.coeff[7])[8:14])
            if i_iter == 0:
                f2 = open("/home/eungyo/AdvSemiSeg/snapshots/log.txt", 'w')
                f2.write(wdata)
                f2.close()
                #f3 = open("/home/eungyo/AdvSemiSeg/snapshots/coeff.txt", 'w')
                #f3.write(wdata2)
                #f3.close()
            else:
                f1 = open("/home/eungyo/AdvSemiSeg/snapshots/log.txt", 'a')
                f1.write(wdata)
                f1.close()
                #f4 = open("/home/eungyo/AdvSemiSeg/snapshots/coeff.txt", 'a')
                #f4.write(wdata2)
                #f4.close()

        if i_iter % args.save_pred_every == 0 and i_iter != 0:  # 5000
            print('taking snapshot ...')
            #state = {'epoch':i_iter, 'state_dict':model.state_dict(),'optim_dict':optimizer.state_dict()}
            #state_D = {'epoch':i_iter, 'state_dict': model_D.state_dict(), 'optim_dict': optimizer_D.state_dict()}
            #torch.save(state, osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth.tar'))
            #torch.save(state_D, osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth.tar'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
# Example no. 6
# (the stray "0" that followed this heading appears to be listing residue, not code)
def main():
    """Train the segmentation network jointly with the flaw detector.

    All settings are read from ``args`` (not a parameter; it is resolved from
    the enclosing/module scope).  Each outer iteration accumulates gradients
    over ``args.iter_size`` sub-steps and then steps both optimizers once:

    1. optional unlabeled batch: the segmentation net is pushed to minimise
       the detector output (``minimum_loss``), weighted by
       ``args.lambda_semi_adv``;
    2. labeled batch: segmentation loss (``loss_calc``) plus the same
       detector-based term weighted by ``args.lambda_adv_pred``;
    3. detector update on the detached labeled prediction against the target
       produced by ``detector.generate_flaw_detector_gt``.

    Snapshots of both networks are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and on the final iteration.

    Raises:
        ValueError: if the semi-supervised branch is reached while
            ``args.partial_data`` is None, i.e. no unlabeled split exists.
    """

    def next_batch(loader, loader_iter):
        """Return ``(batch, iterator)``, restarting the iterator when exhausted.

        Only ``StopIteration`` triggers a restart, so genuine data-loading
        errors (and Ctrl-C) propagate instead of being silently retried.
        """
        try:
            _, batch = next(loader_iter)
        except StopIteration:
            loader_iter = enumerate(loader)
            _, batch = next(loader_iter)
        return batch, loader_iter

    # prepare
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like):
    # a tensor is taken over only when both its name and its size match.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            param.copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = detector.FlawDetector(in_channels=24)
    # model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)

        # No unlabeled split exists in this mode; the semi-supervised branch
        # below checks for this explicitly.
        trainloader_remain = None
        trainloader_remain_iter = None
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE: unpickling can execute arbitrary code; only point
            # --partial-id at files you trust.
            # Pickle data must be read in binary mode on Python 3.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Persist the id order so the same labeled/unlabeled split can be
        # reloaded through args.partial_id.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # First `partial_size` ids are the labeled subset, the rest unlabeled.
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    # Only consumed by the disabled "train with gt" branch further down.
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    minimum_loss = detector.MinimumCriterion()
    detector_loss = detector.FlawDetectorCriterion()

    # bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): size is passed as (input_size[1], input_size[0]), i.e.
    # (w, h); this only matters for non-square crops -- confirm intent.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    # (only referenced by the disabled "train with gt" branch below)
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        # Periodic validation; val() is followed by model.train() to put the
        # network back into training mode.
        if i_iter > 0 and i_iter % 1000 == 0:
            val(model, args.gpu)
            model.train()

        # Per-iteration running values, used for logging only.
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0  # never updated in this variant; printed as 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                if trainloader_remain is None:
                    raise ValueError(
                        'semi-supervised training needs an unlabeled split: '
                        'set args.partial_data, or set lambda_semi and '
                        'lambda_semi_adv to 0')

                batch, trainloader_remain_iter = next_batch(
                    trainloader_remain, trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))

                D_out = model_D(images, pred)

                loss_semi_adv = args.lambda_semi_adv * minimum_loss(
                    D_out)  # ke: SSL loss for unlabeled data

                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.iter_size

                loss_semi_adv.backward()

            # train with source (labeled data)

            batch, trainloader_iter = next_batch(trainloader,
                                                 trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(images, pred))

            loss_adv_pred = args.lambda_adv_pred * minimum_loss(
                D_out)  # ke: SSL loss for labeled data

            loss = loss_seg + loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred from labeled data
            # (detached so this backward pass does not reach the segmentation net)
            pred = pred.detach()

            D_out = interp(model_D(images, pred))

            # labels gets an explicit channel axis: (N, H, W) -> (N, 1, H, W)
            detect_gt = detector.generate_flaw_detector_gt(
                pred,
                labels.view(labels.shape[0], 1, labels.shape[1],
                            labels.shape[2]).cuda(args.gpu), NUM_CLASSES,
                IGNORE_LABEL)
            loss_D = detector_loss(D_out, detect_gt)

            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # # train with gt
            # # get gt labels
            # try:
            #     _, batch = trainloader_gt_iter.__next__()
            # except:
            #     trainloader_gt_iter = enumerate(trainloader_gt)
            #     _, batch = trainloader_gt_iter.__next__()

            # _, labels_gt, _, _ = batch
            # D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            # ignore_mask_gt = (labels_gt.numpy() == 255)

            # D_out = interp(model_D(D_gt_v))
            # loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            # loss_D = loss_D/args.iter_size/2
            # loss_D.backward()
            # loss_D_value += loss_D.data.cpu().numpy()

        # One optimizer step per outer iteration, after all sub-steps have
        # accumulated their gradients.
        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.6f}, loss_adv_l = {3:.6f}, loss_D = {4:.6f}, loss_semi = {5:.6f}, loss_adv_u = {6:.6f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        # Final iteration: save under the total step count and stop.
        if i_iter >= args.num_steps - 1:
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        # Periodic snapshot (skipped at iteration 0).
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    # NOTE(review): the end time is captured but never reported in this
    # variant -- confirm whether an elapsed-time print was intended.
    end = timeit.default_timer()
def _loss_to_float(loss):
    """Return the value of a single-element loss as a Python float.

    The original code indexed ``loss.data.cpu().numpy()`` with ``[0]``, which
    only works when that array is 1-D. On PyTorch >= 0.4 a reduced loss is
    0-dimensional and the index raises ``IndexError``; flattening first
    handles both layouts.
    """
    return float(loss.data.cpu().numpy().reshape(-1)[0])


def _next_batch(loader, loader_iter):
    """Fetch the next batch, restarting the iterator once it is exhausted.

    Args:
        loader: the iterable that ``loader_iter`` was created from.
        loader_iter: an ``enumerate(loader)`` iterator.

    Returns:
        ``(batch, loader_iter)``. The iterator is returned because a fresh
        one is created when the previous one runs out.
    """
    try:
        _, batch = next(loader_iter)
    except StopIteration:
        # Only exhaustion triggers a restart; any other error raised while
        # loading data is allowed to propagate instead of being swallowed.
        loader_iter = enumerate(loader)
        _, batch = next(loader_iter)
    return batch, loader_iter


def _tile_image_channels(images, repeats):
    """Repeat each of the first three channels of ``images`` along dim 1.

    The result is laid out as ``repeats`` copies of channel 0, then
    ``repeats`` copies of channel 1, then ``repeats`` copies of channel 2,
    so it can be multiplied element-wise with a tensor that was tiled with
    ``.repeat(1, 3, 1, 1)``.
    """
    return torch.cat(
        [images[:, c:c + 1].repeat(1, repeats, 1, 1) for c in range(3)], 1)


def _dump_debug_images(mul_img, images, probs, labels, num_classes):
    """Write three debug PNGs for the first label found that is not 0 or 255.

    Args:
        mul_img: product of the tiled class probabilities and tiled image
            channels (``3 * num_classes`` channels).
        images: the input image batch.
        probs: softmax of the segmentation output.
        labels: label batch the samples are looked up in.
        num_classes: channel count of ``probs``.

    Returns:
        True when the images were written, False when no sample in the batch
        has a label value other than 0 and 255.
    """
    for sample_idx in range(labels.shape[0]):
        for cls in np.unique(labels[sample_idx]).tolist():
            # 0 and 255 are skipped (presumably background and the ignore
            # label -- confirm against the dataset definition).
            if cls == 0 or cls == 255:
                continue
            cls = int(cls)

            # The three image channels weighted by this class' probability,
            # rearranged from (3, H, W) to (H, W, 3) for cv2.
            masked = torch.cat([
                mul_img[sample_idx][cls + k * num_classes].unsqueeze(0)
                for k in range(3)
            ], 0)
            masked = masked.transpose(0, 1).transpose(1, 2)
            masked = masked.data.cpu().numpy()

            # Use the same sample as ``masked`` so the three dumps match
            # (the original always took sample 0 here).
            original = images[sample_idx].transpose(0, 1).transpose(1, 2)
            original = original.data.cpu().numpy()

            class_prob = probs[sample_idx][cls].data.cpu().numpy()
            class_prob = class_prob[:, :, np.newaxis] * 255.0

            print(cls)
            cv2.imwrite('/data1/wyc/1.png', masked)
            cv2.imwrite('/data1/wyc/2.png', original)
            cv2.imwrite('/data1/wyc/3.png', class_prob)
            return True
    return False


def main():
    """Train the segmentation network adversarially against a discriminator.

    Reads its configuration from the module-level ``args``. The discriminator
    input is the element-wise product of the tiled class probabilities (or
    one-hot ground truth) and the tiled image channels. Segmentation weights
    are saved to ``args.snapshot_dir`` every 100 iterations and at the final
    step; discriminator weights are not saved.

    Raises:
        NotImplementedError: if ``args.partial_data`` is not 0, because this
            script never builds the data loaders for that case.
    """
    # Debug images are written once, for the first usable sample.
    debug_dumped = False

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)
    model = nn.DataParallel(model)
    model.cuda()

    # load pretrained parameters
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # NOTE(review): the checkpoint loaded from args.restore_from above is
    # discarded in favour of this hard-coded path. Kept as-is to preserve
    # behaviour -- confirm which source is intended.
    saved_state_dict = torch.load(
        '/data1/wyc/AdvSemiSeg/snapshots/VOC_t_baseline_1adv_mul_20000.pth')

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
        else:
            # Marker printed for every parameter that was not restored.
            print(123456)
    model.load_state_dict(new_params)

    model.train()

    cudnn.benchmark = True

    # init D
    model_D = Discriminator2(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data != 0:
        # The original fell through to a NameError on ``trainloader`` here;
        # fail with an explicit message instead.
        raise NotImplementedError(
            'partial_data={!r} was requested, but this script only builds '
            'data loaders for partial_data == 0'.format(args.partial_data))

    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=5,
                                  pin_memory=True)

    trainloader_gt = data.DataLoader(train_gt_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=5,
                                     pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = torch.nn.BCELoss()

    # NOTE(review): size is passed as (w, h) although input_size is (h, w);
    # harmless for square crops -- confirm for non-square inputs.
    upsample_kwargs = {
        'size': (input_size[1], input_size[0]),
        'mode': 'bilinear',
    }
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        upsample_kwargs['align_corners'] = True
    interp = nn.Upsample(**upsample_kwargs)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # Script file name up to its first dot; used in the snapshot file names.
    script_name = os.path.abspath(__file__).split('/')[-1].split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        # Never updated in this script; kept because the log line prints them.
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for _ in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source
            batch, trainloader_iter = _next_batch(trainloader,
                                                  trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            pred = interp(model(images))
            loss_seg = loss_calc(pred, labels)

            probs = F.softmax(pred, dim=1)
            # Channel count of the prediction (the original hard-coded 21).
            num_classes = probs.size(1)

            img_re = _tile_image_channels(images, num_classes)
            mul_img = probs.repeat(1, 3, 1, 1) * img_re

            if not debug_dumped:
                debug_dumped = _dump_debug_images(mul_img, images, probs,
                                                  labels, num_classes)

            D_out = model_D(mul_img)

            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, D_out))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += _loss_to_float(loss_seg) / args.iter_size
            loss_adv_pred_value += (_loss_to_float(loss_adv_pred) /
                                    args.iter_size)

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred; detached so no gradient reaches the
            # segmentation network
            mul_img2 = probs.detach().repeat(1, 3, 1, 1) * img_re

            D_out = model_D(mul_img2)

            loss_D = bce_loss(D_out, make_D_label(pred_label, D_out))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += _loss_to_float(loss_D)

            # train with gt
            # get gt labels
            batch, trainloader_gt_iter = _next_batch(trainloader_gt,
                                                     trainloader_gt_iter)

            img_gt, labels_gt, _, _ = batch
            img_gt = Variable(img_gt).cuda()
            D_gt_v = Variable(one_hot(labels_gt)).cuda()

            img_re3 = _tile_image_channels(img_gt, D_gt_v.size(1))
            mul_img3 = D_gt_v.repeat(1, 3, 1, 1) * img_re3

            D_out = model_D(mul_img3)

            loss_D = bce_loss(D_out, make_D_label(gt_label, D_out))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += _loss_to_float(loss_D)

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + script_name + '_' +
                    str(args.num_steps) + '.pth'))
            break

        if i_iter % 100 == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + script_name + '_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
# Example no. 8
# 0
def main():

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

        # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    # only copy the params that exist in current model (caffe-like)

    # state_dict = torch.load(
    #     '/data1/wyc/AdvSemiSeg/snapshots/VOC_t_baseline_1adv_mul_20000.pth')  # baseline707 adv 709 nadv 705()*2
    # from collections import OrderedDict
    # new_state_dict = OrderedDict()
    # for k, v in state_dict.items():
    #     name = k[7:]  # remove `module.`
    #     new_state_dict[name] = v
    # # load params
    #
    # new_params = model.state_dict().copy()
    # for name, param in new_params.items():
    #     print (name)
    #     if name in new_state_dict and param.size() == new_state_dict[name].size():
    #         new_params[name].copy_(new_state_dict[name])
    #         print('copy {}'.format(name))
    #
    # model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D = Discriminator2(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data == 0:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        #sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = torch.nn.BCELoss()
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        tw = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
            20
        ]

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source
            try:
                _, batch = trainloader_iter.next()
            except:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.next()

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))
            loss_seg = loss_calc(pred, labels)

            pred_0 = F.softmax(pred, dim=1)
            #pred_0 = 1 / (math.e ** (((pred_01 - 0.33) * 30) * (-1)) + 1)

            labels0 = Variable(one_hot(labels)).cuda()
            one_s = Variable(torch.ones(labels0.size())).cuda()
            labels0 = one_s - labels0
            labels0 = torch.index_select(
                labels0, 1,
                Variable(
                    torch.LongTensor([
                        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                        17, 18, 19, 20
                    ])).cuda())
            pred0 = torch.index_select(
                pred_0, 1,
                Variable(
                    torch.LongTensor([
                        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                        17, 18, 19, 20
                    ])).cuda())

            pred_label0 = labels0 * pred0
            pred_max = torch.max(pred_0, dim=1)[1]

            pred_max = Variable(one_hot(pred_max.cpu().data)).cuda()

            pred_max = torch.index_select(
                pred_max, 1,
                Variable(
                    torch.LongTensor([
                        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
                        17, 18, 19, 20
                    ])).cuda())

            pred_c = pred_label0 * (pred_max)
            one_s = Variable(torch.ones(pred_max.size())).cuda()
            pred_m = pred_label0 * (one_s - pred_max)

            c3 = 0
            c4 = 0
            c5 = 0
            c6 = 0
            c7 = 0
            c8 = 0
            c9 = 0
            c10 = 0
            c0 = 0

            pred_min_c_list = pred_c.cpu().data.numpy().flatten().tolist()
            pred_min_c_l = len(pred_min_c_list)

            for n in pred_min_c_list:
                if n < 0.00000001:
                    c0 = c0 + 1
                elif n < 0.1:
                    c3 = c3 + 1
                elif n < 0.2:
                    c4 = c4 + 1
                elif n < 0.3:
                    c5 = c5 + 1
                elif n < 0.4:
                    c6 = c6 + 1
                elif n < 0.5:
                    c7 = c7 + 1
                elif n < 0.6:
                    c8 = c8 + 1
                elif n < 0.7:
                    c9 = c9 + 1
                elif n <= 1:
                    c10 = c10 + 1
                else:
                    print n

            if pred_min_c_l - c0 == 0:
                print pred_min_c_l
            else:
                pred_min_c_l = (pred_min_c_l - c0) * 1.00000
                print "correct", 3, ":", c3 / pred_min_c_l, 4, ":", c4 / pred_min_c_l, 5, ":", c5 / pred_min_c_l, 6, ":", c6 / pred_min_c_l, 7, ":", c7 / pred_min_c_l, 8, ":", c8 / pred_min_c_l, 9, ":", c9 / pred_min_c_l, 10, ":", c10 / pred_min_c_l

            c3 = 0
            c4 = 0
            c5 = 0
            c6 = 0
            c7 = 0
            c8 = 0
            c9 = 0
            c10 = 0
            c0 = 0

            pred_min_m_list = pred_m.cpu().data.numpy().flatten().tolist()
            pred_min_c_l = len(pred_min_m_list)

            for n in pred_min_m_list:
                if n < 0.0000001:
                    c0 = c0 + 1
                elif n < 0.005:
                    c3 = c3 + 1
                elif n < 0.05:
                    c4 = c4 + 1
                elif n < 0.3:
                    c5 = c5 + 1
                elif n < 0.4:
                    c6 = c6 + 1
                elif n < 0.5:
                    c7 = c7 + 1
                else:
                    print n
            if pred_min_c_l - c0 == 0:
                print pred_min_c_l
            else:
                pred_min_c_l = (pred_min_c_l - c0) * 1.00000
                print "mistake", 1, ":", c3 / pred_min_c_l, 2, ":", c4 / pred_min_c_l, 3, ":", c5 / pred_min_c_l, 4, ":", c6 / pred_min_c_l, 5, ":", c7 / pred_min_c_l
            '''

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))
            loss_seg = loss_calc(pred, labels)

            pred_0 = F.softmax(pred, dim=1)

            labels0 = Variable(one_hot(labels)).cuda()
            labels0=torch.index_select(labels0, 1, Variable(torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])).cuda())
            pred0 = torch.index_select(pred_0, 1, Variable(torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])).cuda())


            pred_label0=labels0*pred0
            pred_max = torch.max(pred_0, dim=1)[1]


            pred_max = Variable(one_hot(pred_max.cpu().data)).cuda()

            pred_max = torch.index_select(pred_max, 1, Variable(
                torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])).cuda())


            pred_c=pred_label0*(pred_max)
            one_s = Variable(torch.ones(pred_max.size())).cuda()
            pred_m = pred_label0 * (one_s - pred_max)

            c3 = 0
            c4 = 0
            c5 = 0
            c6 = 0
            c7 = 0
            c8 = 0
            c9 = 0
            c10 = 0
            c0 = 0

            pred_min_c_list = pred_c.cpu().data.numpy().flatten().tolist()
            pred_min_c_l = len(pred_min_c_list)

            for n in pred_min_c_list:
                if n < 0.00000001:
                    c0 = c0 + 1
                elif n < 0.3:
                    c3 = c3 + 1
                elif n < 0.4:
                    c4 = c4 + 1
                elif n < 0.5:
                    c5 = c5 + 1
                elif n < 0.6:
                    c6 = c6 + 1
                elif n < 0.7:
                    c7 = c7 + 1
                elif n < 0.8:
                    c8 = c8 + 1
                elif n < 0.9:
                    c9 = c9 + 1
                elif n <= 1:
                    c10 = c10 + 1
                else:
                    print n

            if pred_min_c_l - c0 == 0:
                print pred_min_c_l
            else:
                pred_min_c_l = (pred_min_c_l - c0) * 1.00000
                print "correct", 3, ":", c3 / pred_min_c_l, 4, ":", c4 / pred_min_c_l, 5, ":", c5 / pred_min_c_l, 6, ":", c6 / pred_min_c_l, 7, ":", c7 / pred_min_c_l, 8, ":", c8 / pred_min_c_l, 9, ":", c9 / pred_min_c_l, 10, ":", c10 / pred_min_c_l

            c3 = 0
            c4 = 0
            c5 = 0
            c6 = 0
            c7 = 0
            c8 = 0
            c9 = 0
            c10 = 0
            c0 = 0

            pred_min_m_list = pred_m.cpu().data.numpy().flatten().tolist()
            pred_min_c_l = len(pred_min_m_list)


            for n in pred_min_m_list:
                if n < 0.0000001:
                    c0 = c0 + 1
                elif n < 0.1:
                    c3 = c3 + 1
                elif n < 0.2:
                    c4 = c4 + 1
                elif n < 0.3:
                    c5 = c5 + 1
                elif n < 0.4:
                    c6 = c6 + 1
                elif n < 0.5:
                    c7 = c7 + 1
                else:
                    print n
            if pred_min_c_l - c0 == 0:
                print pred_min_c_l
            else:
                pred_min_c_l = (pred_min_c_l - c0) * 1.00000
                print "mistake", 1, ":", c3 / pred_min_c_l, 2, ":", c4 / pred_min_c_l, 3, ":", c5 / pred_min_c_l, 4, ":", c6 / pred_min_c_l, 5, ":", c7 / pred_min_c_l
            '''

            # c3=0
            # c4=0
            # c5=0
            # c6=0
            # c7=0
            # c8=0
            # c9=0
            # c10=0
            # c0=0
            #
            # pred_min_c_list=pred_c.cpu().data.numpy().flatten().tolist()
            #
            # pred_min_c=set(pred_c.cpu().data.numpy().flatten().tolist())
            # pred_min_c_l=len(pred_min_c_list)
            # if len(pred_min_c) >1:
            #     pred_min_c.discard(0.0)
            #     pred_min_c=list(pred_min_c)
            #     pred_min_c2 = min(pred_min_c)
            #     pred_max_c2 = max(pred_min_c)
            # else:
            #     pred_min_c2 = 999
            #     pred_max_c2 = 999
            # for n in pred_min_c_list:
            #     if n<0.00000001:
            #         c0=c0+1
            #     elif n<0.3:
            #         c3=c3+1
            #     elif n<0.4:
            #         c4=c4+1
            #     elif n<0.5:
            #         c5=c5+1
            #     elif n<0.6:
            #         c6=c6+1
            #     elif n<0.7:
            #         c7=c7+1
            #     elif n<0.8:
            #         c8=c8+1
            #     elif n<0.9:
            #         c9=c9+1
            #     elif n<=1:
            #         c10=c10+1
            #     else:
            #         print n
            #
            # if pred_min_c_l-c0==0:
            #     print pred_min_c_l
            # else:
            #     pred_min_c_l=(pred_min_c_l-c0)*1.00000
            #
            # #print c0 + c3 + c4 + c5 + c6 + c7+c8+c9+c10
            #
            #     print "correct",3,":",c3/pred_min_c_l,4,":",c4/pred_min_c_l,5,":",c5/pred_min_c_l,6,":",c6/pred_min_c_l,7,":",c7/pred_min_c_l,8,":",c8/pred_min_c_l,9,":",c9/pred_min_c_l,10,":",c10/pred_min_c_l
            #
            # c3 = 0
            # c4 = 0
            # c5 = 0
            # c6 = 0
            # c7 = 0
            # c8 = 0
            # c9 = 0
            # c10 = 0
            # c0 = 0
            #
            #
            #
            #
            # pred_min_m_list = pred_m.cpu().data.numpy().flatten().tolist()
            # pred_min_m = set(pred_m.cpu().data.numpy().flatten().tolist())
            # pred_min_c_l = len(pred_min_m_list)
            # if len(pred_min_m) >1:
            #     pred_min_m.discard(0.0)
            #     pred_min_m=list(pred_min_m)
            #     pred_min_m2 = min(pred_min_m)
            #     pred_max_m2 = max(pred_min_m)
            # else:
            #     pred_min_m2 = 999
            #     pred_max_m2 = 999
            #
            # for n in pred_min_m_list:
            #     if n<0.0000001:
            #         c0=c0+1
            #     elif n<0.1:
            #         c3=c3+1
            #     elif n<0.2:
            #         c4=c4+1
            #     elif n<0.3:
            #         c5=c5+1
            #     elif n<0.4:
            #         c6=c6+1
            #     elif n<0.5:
            #         c7=c7+1
            #     else:
            #         print n
            # if pred_min_c_l-c0==0:
            #     print pred_min_c_l
            # else:
            #     pred_min_c_l=(pred_min_c_l-c0)*1.00000
            #     print "mistake",1,":",c3/pred_min_c_l,2,":",c4/pred_min_c_l,3,":",c5/pred_min_c_l,4,":",c6/pred_min_c_l,5,":",c7/pred_min_c_l
            #
            #
            #
            # print('max c  {}  min c  {}  max m  {}  min m  {}'.format(
            #                         pred_max_c2,pred_min_c2, pred_max_m2,pred_min_m2))
            '''20000
correct 3 : 6.60318801918e-05 4 : 0.00186540061542 5 : 0.00659328323715 6 : 0.0196708971091 7 : 0.0234446190621 8 : 0.0315071116335 9 : 0.0565364958202 10 : 0.860316160642
mistake 1 : 0.326503117461 2 : 0.211204060532 3 : 0.199110192061 4 : 0.15803490303 5 : 0.105147726917
max c  1.0  min c  0.206159025431  max m  0.499729216099  min m  3.0866687678e-11
iter =        0/   20000, loss_seg = 0.164, loss_adv_p = 0.688, loss_D = 0.687, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 1.03566126972e-05 4 : 0.000749128318431 5 : 0.00340042116892 6 : 0.0126937549625 7 : 0.0161735768288 8 : 0.0261400904478 9 : 0.0492767632133 10 : 0.891555908448
mistake 1 : 0.398261661669 2 : 0.20677876039 3 : 0.167001935704 4 : 0.129388545185 5 : 0.0985690970509
max c  1.0  min c  0.295039653778  max m  0.499661713839  min m  1.87631424287e-10
iter =        1/   20000, loss_seg = 0.114, loss_adv_p = 0.621, loss_D = 0.666, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.000170320224732 4 : 0.00122630561807 5 : 0.00573184329631 6 : 0.0155854360311 7 : 0.0216533779042 8 : 0.0339187050213 9 : 0.0683937894433 10 : 0.853320222461
mistake 1 : 0.370818679971 2 : 0.185940950356 3 : 0.165003680379 4 : 0.165167252801 5 : 0.113069436493
max c  1.0  min c  0.205323472619  max m  0.499007672071  min m  2.63143761003e-07
iter =        2/   20000, loss_seg = 0.136, loss_adv_p = 6.856, loss_D = 2.909, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 5.57311529183e-05 4 : 0.00162151116348 5 : 0.00529976725609 6 : 0.0104801106131 7 : 0.0125395094066 8 : 0.0183382031746 9 : 0.032151567505 10 : 0.919513599728
mistake 1 : 0.599129542479 2 : 0.138565514313 3 : 0.114044089027 4 : 0.0929903400446 5 : 0.0552705141361
max c  1.0  min c  0.251725673676  max m  0.499345749617  min m  2.17372370104e-11
iter =        3/   20000, loss_seg = 0.173, loss_adv_p = 0.211, loss_D = 1.186, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0 4 : 0.000408403449911 5 : 0.00314207170335 6 : 0.00972922412127 7 : 0.0128537300848 8 : 0.0230989478122 9 : 0.0531056227933 10 : 0.897662000035
mistake 1 : 0.376961004034 2 : 0.197477108279 3 : 0.160338093104 4 : 0.144009732983 5 : 0.1212140616
max c  1.0  min c  0.303397655487  max m  0.49908259511  min m  2.31815082538e-13

            '''
            '''0000
            
mistake 1 : 1.0 2 : 0.0 3 : 0.0 4 : 0.0 5 : 0.0
iter =        0/   20000, loss_seg = 3.045, loss_adv_p = 0.677, loss_D = 0.689, loss_semi = 0.000, loss_semi_adv = 0.000
20608200
mistake 1 : 1.0 2 : 0.0 3 : 0.0 4 : 0.0 5 : 0.0
iter =        1/   20000, loss_seg = 2.311, loss_adv_p = 0.001, loss_D = 4.009, loss_semi = 0.000, loss_semi_adv = 0.000
20608200
mistake 1 : 1.0 2 : 0.0 3 : 0.0 4 : 0.0 5 : 0.0
iter =        2/   20000, loss_seg = 2.695, loss_adv_p = 1.304, loss_D = 0.996, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.215316600114 4 : 0.28849591177 5 : 0.159402928313 6 : 0.113015782468 7 : 0.102595550485 8 : 0.0921563034797 9 : 0.0290169233695 10 : 0.0
mistake 1 : 0.80735288413 2 : 0.095622890725 3 : 0.0764600062066 4 : 0.0203387447147 5 : 0.000225474223205
iter =        3/   20000, loss_seg = 2.103, loss_adv_p = 0.936, loss_D = 0.720, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0 4 : 0.0106591946386 5 : 0.0159743486048 6 : 0.0144144664625 7 : 0.0482119128777 8 : 0.1030099948 9 : 0.149690912242 10 : 0.658039170374
mistake 1 : 0.997786787186 2 : 0.00220988967167 3 : 0.0 4 : 3.32314236342e-06 5 : 0.0
iter =        4/   20000, loss_seg = 3.231, loss_adv_p = 0.352, loss_D = 0.740, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0547200197126 4 : 0.0824408402489 5 : 0.105990337314 6 : 0.0824672410303 7 : 0.0818248220147 8 : 0.106500752422 9 : 0.12377566376 10 : 0.362280323498
mistake 1 : 0.89707796374 2 : 0.0776956426097 3 : 0.0212782533776 4 : 0.0038358600209 5 : 0.000112280251508
iter =        5/   20000, loss_seg = 2.440, loss_adv_p = 0.285, loss_D = 0.796, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0645113648539 4 : 0.11994706344 5 : 0.133270975533 6 : 0.130346143711 7 : 0.130461978634 8 : 0.155641595163 9 : 0.129387609717 10 : 0.136433268948
mistake 1 : 0.687583743536 2 : 0.172921294404 3 : 0.108947452597 4 : 0.027046718675 5 : 0.00350079078777
iter =        6/   20000, loss_seg = 1.562, loss_adv_p = 0.471, loss_D = 0.700, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0517260346307 4 : 0.0852126465864 5 : 0.0921211854525 6 : 0.0708675276672 7 : 0.0477657257266 8 : 0.0341796660139 9 : 0.0336846274009 10 : 0.584442586522
mistake 1 : 0.896840471376 2 : 0.0747503216715 3 : 0.0228403760663 4 : 0.00484719754372 5 : 0.000721633342184
iter =        7/   20000, loss_seg = 1.521, loss_adv_p = 1.028, loss_D = 0.781, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0354795814142 4 : 0.146293048623 5 : 0.182080980555 6 : 0.160111006186 7 : 0.0642910828885 8 : 0.0516294397657 9 : 0.0560234346393 10 : 0.304091425928
mistake 1 : 0.927571328438 2 : 0.0370632102681 3 : 0.0252651475979 4 : 0.00991933537868 5 : 0.000180978317074
iter =        8/   20000, loss_seg = 2.240, loss_adv_p = 1.128, loss_D = 1.030, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.013791534851 4 : 0.0218762276947 5 : 0.0420707290008 6 : 0.0509204695049 7 : 0.0551454624403 8 : 0.0738167607469 9 : 0.109429384722 10 : 0.632949431039
mistake 1 : 0.82698518928 2 : 0.0871808088324 3 : 0.056324796628 4 : 0.0222555690883 5 : 0.00725363617091
iter =        9/   20000, loss_seg = 1.620, loss_adv_p = 0.856, loss_D = 0.680, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0111066673332 4 : 0.0219300438268 5 : 0.039156626506 6 : 0.0714017897315 7 : 0.0741513772934 8 : 0.0834458164609 9 : 0.100755719975 10 : 0.598051958873
mistake 1 : 0.666796397198 2 : 0.104720338041 3 : 0.101412209496 4 : 0.0761078060714 5 : 0.0509632491938
iter =       10/   20000, loss_seg = 1.065, loss_adv_p = 0.474, loss_D = 0.706, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0158021057795 4 : 0.0165904591721 5 : 0.0267981756919 6 : 0.0441302710184 7 : 0.0592082596077 8 : 0.0964185397359 9 : 0.156169887236 10 : 0.584882301758
mistake 1 : 0.693582996679 2 : 0.111521356046 3 : 0.114913868543 4 : 0.0540284587444 5 : 0.0259533199871
iter =       11/   20000, loss_seg = 1.008, loss_adv_p = 0.442, loss_D = 0.724, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.00525433110547 4 : 0.0132143122692 5 : 0.0181547053526 6 : 0.0353758581661 7 : 0.0438108771263 8 : 0.057140850772 9 : 0.0862049023901 10 : 0.740844162818
mistake 1 : 0.698567424322 2 : 0.152478802192 3 : 0.0874738631406 4 : 0.0430223987244 5 : 0.0184575116206
iter =       12/   20000, loss_seg = 0.957, loss_adv_p = 0.442, loss_D = 0.743, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.00325696566801 4 : 0.0222856760435 5 : 0.029201151092 6 : 0.0513157992579 7 : 0.0617187558094 8 : 0.0804648983871 9 : 0.10184338308 10 : 0.649913370662
mistake 1 : 0.910824450383 2 : 0.0481368156486 3 : 0.0238175658895 4 : 0.0128053565929 5 : 0.00441581148608
iter =       13/   20000, loss_seg = 1.422, loss_adv_p = 0.830, loss_D = 0.618, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0149639294303 4 : 0.0338563367539 5 : 0.0443977524345 6 : 0.0406656984358 7 : 0.0549094069189 8 : 0.0846765553201 9 : 0.131008785505 10 : 0.595521535202
mistake 1 : 0.672923967008 2 : 0.177251807277 3 : 0.079961731589 4 : 0.0552797386879 5 : 0.0145827554388
iter =       14/   20000, loss_seg = 0.997, loss_adv_p = 1.154, loss_D = 0.837, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0193227170348 4 : 0.0525052357055 5 : 0.0709388706682 6 : 0.0720057691548 7 : 0.0699806377682 8 : 0.0911210337061 9 : 0.15581657249 10 : 0.468309163473
mistake 1 : 0.815034514788 2 : 0.1099266077 3 : 0.0481236762295 4 : 0.0204414796676 5 : 0.0064737216159
iter =       15/   20000, loss_seg = 0.938, loss_adv_p = 1.266, loss_D = 0.830, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.0138818759925 4 : 0.0214850753368 5 : 0.0245366000015 6 : 0.0524393902805 7 : 0.0717583953517 8 : 0.0829619547321 9 : 0.127629836154 10 : 0.605306872151
mistake 1 : 0.845186224931 2 : 0.0863806795239 3 : 0.037294979522 4 : 0.0189205299287 5 : 0.0122175860942
iter =       16/   20000, loss_seg = 1.872, loss_adv_p = 0.826, loss_D = 0.847, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.00228016623147 4 : 0.0089505718804 5 : 0.0161634364312 6 : 0.0323728439557 7 : 0.0323958295024 8 : 0.0427347284028 9 : 0.0699726012283 10 : 0.795129822368
mistake 1 : 0.815070892754 2 : 0.0785075426593 3 : 0.0488263539692 4 : 0.0359976506471 5 : 0.0215975599703
iter =       17/   20000, loss_seg = 1.717, loss_adv_p = 0.603, loss_D = 0.620, loss_semi = 0.000, loss_semi_adv = 0.000
correct 3 : 0.00200898622728 4 : 0.00781591308424 5 : 0.0269459263818 6 : 0.0502692998205 7 : 0.0554862862773 8 : 0.072763567832 9 : 0.10771673932 10 : 0.676993281057



            '''

            pred_re = F.softmax(pred, dim=1).repeat(1, 3, 1, 1)
            indices_1 = torch.index_select(
                images, 1,
                Variable(torch.LongTensor([0])).cuda())
            indices_2 = torch.index_select(
                images, 1,
                Variable(torch.LongTensor([1])).cuda())
            indices_3 = torch.index_select(
                images, 1,
                Variable(torch.LongTensor([2])).cuda())
            img_re = torch.cat([
                indices_1.repeat(1, 21, 1, 1),
                indices_2.repeat(1, 21, 1, 1),
                indices_3.repeat(1, 21, 1, 1),
            ], 1)

            mul_img = pred_re * img_re

            D_out = model_D(mul_img)

            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, D_out))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            )[0] / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            pred_re2 = F.softmax(pred, dim=1).repeat(1, 3, 1, 1)

            mul_img2 = pred_re2 * img_re

            D_out = model_D(mul_img2)

            loss_D = bce_loss(D_out, make_D_label(pred_label, D_out))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

            # train with gt
            # get gt labels
            try:
                _, batch = trainloader_gt_iter.next()
            except:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = trainloader_gt_iter.next()

            img_gt, labels_gt, _, _ = batch
            img_gt = Variable(img_gt).cuda()
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            pred_re3 = D_gt_v.repeat(1, 3, 1, 1)
            indices_1 = torch.index_select(
                img_gt, 1,
                Variable(torch.LongTensor([0])).cuda())
            indices_2 = torch.index_select(
                img_gt, 1,
                Variable(torch.LongTensor([1])).cuda())
            indices_3 = torch.index_select(
                img_gt, 1,
                Variable(torch.LongTensor([2])).cuda())
            img_re3 = torch.cat([
                indices_1.repeat(1, 21, 1, 1),
                indices_2.repeat(1, 21, 1, 1),
                indices_3.repeat(1, 21, 1, 1),
            ], 1)

            mul_img3 = pred_re3 * img_re3

            D_out = model_D(mul_img3)

            loss_D = bce_loss(D_out, make_D_label(gt_label, D_out))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        #print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' +
                    os.path.abspath(__file__).split('/')[-1].split('.')[0] +
                    '_' + str(args.num_steps) + '.pth'))
            #torch.save(model_D.state_dict(),osp.join(args.snapshot_dir, 'VOC_'+os.path.abspath(__file__).split('/')[-1].split('.')[0]+'_'+str(args.num_steps)+'_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' +
                    os.path.abspath(__file__).split('/')[-1].split('.')[0] +
                    '_' + str(i_iter) + '.pth'))
            #torch.save(model_D.state_dict(),osp.join(args.snapshot_dir, 'VOC_'+os.path.abspath(__file__).split('/')[-1].split('.')[0]+'_'+str(i_iter)+'_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Exemplo n.º 9
0
def main():
    """Adversarially train the segmentation network, optionally with a second discriminator.

    All configuration is read from the module-level ``args`` namespace. The
    function:

    * builds ``Res_Deeplab`` and copies in every pretrained tensor whose name
      and size match (``args.restore_from`` may be a URL or a local path);
    * builds the ``FCDiscriminator`` ``model_D`` (optionally restored from
      ``args.restore_from_D``);
    * builds the labeled / ground-truth loaders and, when
      ``args.partial_data`` is set, an additional loader over the remaining
      (treated as unlabeled) samples; the sampled ids are written to
      ``<snapshot_dir>/train_id.pkl``;
    * runs ``args.num_steps`` iterations, each accumulating gradients over
      ``args.iter_size`` sub-iterations before stepping the optimizers;
    * at iteration ``args.discr_split`` clones ``model_D`` into ``model_D_ul``,
      which from then on handles the unlabeled predictions;
    * saves the segmentation network and ``model_D`` every
      ``args.save_pred_every`` iterations and at the final iteration.

    The elapsed time printed at the end relies on a ``start`` timestamp that is
    not defined in this function (presumably a module-level
    ``timeit.default_timer()`` value -- confirm).
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)
    model.train()
    model.cuda(args.gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle files must be opened in binary mode; the context manager
            # also guarantees the handle is closed.
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # only point args.partial_id at trusted files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the split so the same labeled subset can be reused later.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)
    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    # NOTE(review): nn.Upsample expects size=(H, W) but this passes
    # (input_size[1], input_size[0]) == (w, h); harmless for square crops --
    # confirm for non-square input sizes.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_adv_pred_value = 0

        loss_D_value = 0
        loss_D_ul_value = 0

        # loss_semi_value is only reported; nothing in this loop updates it.
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # creating 2nd discriminator as a copy of the 1st one
        if i_iter == args.discr_split:
            model_D_ul = FCDiscriminator(num_classes=args.num_classes)
            model_D_ul.load_state_dict(model_D.state_dict())
            model_D_ul.train()
            model_D_ul.cuda(args.gpu)

            optimizer_D_ul = optim.Adam(model_D_ul.parameters(),
                                        lr=args.learning_rate_D,
                                        betas=(0.9, 0.99))

        # start training 2nd discriminator after specified number of steps
        if i_iter >= args.discr_split:
            optimizer_D_ul.zero_grad()
            adjust_learning_rate_D(optimizer_D_ul, i_iter)

        for sub_i in range(args.iter_size):
            # train Segmentation

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # don't accumulate grads in D_ul, in case split has already been made
            if i_iter >= args.discr_split:
                for param in model_D_ul.parameters():
                    param.requires_grad = False

            # do semi-supervised training first
            if args.lambda_semi_adv > 0 and i_iter >= args.semi_start_adv:
                # next() works on both Python 2 and 3; only exhaustion of the
                # loader should restart it, other errors must surface.
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                pred_remain = pred.detach()

                # choose discriminator depending on the iteration
                # (dim=1 is what the implicit-dim softmax picks for the 4-D
                # tensors produced by the bilinear interp; made explicit here)
                if i_iter >= args.discr_split:
                    D_out = interp(model_D_ul(F.softmax(pred, dim=1)))
                else:
                    D_out = interp(model_D(F.softmax(pred, dim=1)))

                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)
                # Unlabeled images have no ignore regions: all-False mask.
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape, dtype=bool)

                # adversarial loss
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain,
                                        args.gpu))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # true loss value without multiplier
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv
                loss_semi_adv.backward()

            else:
                # No unlabeled forward pass in this sub-iteration: mark the
                # unlabeled tensors as absent so the discriminator steps below
                # skip them instead of hitting an unbound name.
                pred_remain = None
                ignore_mask_remain = None

            # train with labeled images
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            pred = interp(model(images))
            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # computing loss
            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_adv_pred = bce_loss(
                D_out, make_D_label(gt_label, ignore_mask, args.gpu))
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D and D_ul

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            if i_iter >= args.discr_split:
                for param in model_D_ul.parameters():
                    param.requires_grad = True

            # train D with pred
            pred = pred.detach()

            # before split, train D with both labeled and unlabeled
            if (args.D_remain and i_iter < args.discr_split
                    and (args.lambda_semi > 0 or args.lambda_semi_adv > 0)
                    and pred_remain is not None):
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain),
                                             axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out,
                              make_D_label(pred_label, ignore_mask, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pred on unlabeled
            if (i_iter >= args.discr_split
                    and (args.lambda_semi > 0 or args.lambda_semi_adv > 0)
                    and pred_remain is not None):
                D_ul_out = interp(model_D_ul(F.softmax(pred_remain, dim=1)))
                loss_D_ul = bce_loss(
                    D_ul_out,
                    make_D_label(pred_label, ignore_mask_remain, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            images_gt, labels_gt, _, _ = batch
            images_gt = Variable(images_gt).cuda(args.gpu)
            # Predictions on labeled images serve as the "real" samples for
            # D_ul below; no gradient is needed through the segmentation net.
            with torch.no_grad():
                pred_l = interp(model(images_gt))

            # train D with gt
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out,
                              make_D_label(gt_label, ignore_mask_gt, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pseudo_gt (gt are substituted for pred)
            if i_iter >= args.discr_split:
                D_ul_out = interp(model_D_ul(F.softmax(pred_l, dim=1)))
                loss_D_ul = bce_loss(
                    D_ul_out, make_D_label(gt_label, ignore_mask_gt, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()
        if i_iter >= args.discr_split:
            optimizer_D_ul.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_D_ul={5:.3f}, loss_semi = {6:.3f}, loss_semi_adv = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_D_ul_value,
                    loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            # Checkpoints are taken from the networks trained above
            # (model / model_D).
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.random_seed) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.random_seed) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(i_iter) + '_' +
                    str(args.random_seed) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(i_iter) + '_' +
                    str(args.random_seed) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Exemplo n.º 10
0
def main():
    """Adversarially train the segmentation network with one discriminator.

    All configuration is read from the module-level ``args`` namespace. The
    function:

    * builds ``Res_Deeplab`` and copies in every pretrained tensor whose name
      and size match (``args.restore_from`` may be a URL or a local path);
    * builds the ``FCDiscriminator`` ``model_D`` (optionally restored from
      ``args.restore_from_D``);
    * builds the labeled / ground-truth loaders and, when
      ``args.partial_data`` is set, a loader over the remaining samples that
      are used without their labels; the sampled ids are written to
      ``<snapshot_dir>/train_id.pkl``;
    * runs ``args.num_steps`` iterations, each accumulating gradients over
      ``args.iter_size`` sub-iterations: an optional semi-supervised step that
      self-labels pixels where the discriminator's sigmoid output is at least
      ``args.mask_T``, a supervised + adversarial step for the segmentation
      network, then discriminator steps on predictions and on ground truth;
    * saves both networks every ``args.save_pred_every`` iterations and at the
      final iteration.

    The elapsed time printed at the end relies on a ``start`` timestamp that is
    not defined in this function (presumably a module-level
    ``timeit.default_timer()`` value -- confirm).
    """
    # Parse args.input_size (an 'h,w' string) into a pair of ints.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = False
    gpu = args.gpu

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    # A tensor is copied only when its name exists in the checkpoint and its
    # size matches the model's. state_dict() returns a dict holding the whole
    # module state: parameters and persistent buffers, keyed by their names.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # switch to training mode
    model.train()
    cudnn.benchmark = True

    model.cuda(gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)
    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickle files must be opened in binary mode; the context manager
            # also guarantees the handle is closed.
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # only point args.partial_id at trusted files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            # A concrete list (not a range object) so it can be shuffled in
            # place.
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Write the split to disk so the same labeled subset can be reused.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): nn.Upsample expects size=(H, W) but this passes
    # (input_size[1], input_size[0]) == (w, h); harmless for square crops --
    # confirm for non-square input sizes.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        print("Iter:", i_iter)
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                # Only exhaustion of the loader should restart it; any other
                # error must surface instead of being swallowed.
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch

                images = Variable(images).cuda(gpu)

                pred = interp(model(images))
                # dim=1 is what the implicit-dim softmax picks for the 4-D
                # tensors produced by the bilinear interp; made explicit here.
                D_out = interp(model_D(F.softmax(pred, dim=1)))

                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                # produce ignore mask: pixels the discriminator scores below
                # the threshold are excluded from the self-training target
                semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt[semi_ignore_mask] = 255

                semi_ratio = 1.0 - float(
                    semi_ignore_mask.sum()) / semi_ignore_mask.size
                print('semi ratio: {:.4f}'.format(semi_ratio))

                # With every pixel ignored there is nothing to learn from.
                if semi_ratio != 0.0:
                    semi_gt = torch.FloatTensor(semi_gt)

                    loss_semi = args.lambda_semi * loss_calc(
                        pred, semi_gt, gpu)
                    loss_semi = loss_semi / args.iter_size
                    loss_semi.backward()
                    # NOTE(review): the [0] indexing here and below assumes
                    # the loss is a 1-element (not 0-dim) tensor, as on
                    # pre-0.4 PyTorch -- confirm against the target version.
                    loss_semi_value += loss_semi.data.cpu().numpy(
                    )[0] / args.lambda_semi

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch

            images = Variable(images).cuda(gpu)

            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, gpu)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy()[0] / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            )[0] / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
Exemplo n.º 11
0
def _next_batch(loader, loader_iter):
    """Return the next batch from ``loader_iter``, restarting it when exhausted.

    Args:
        loader: Re-iterable source of batches (a ``DataLoader`` in ``main``).
        loader_iter: Current ``enumerate(loader)`` iterator.

    Returns:
        A ``(batch, loader_iter)`` tuple; ``loader_iter`` is a fresh iterator
        if the previous one had run out.

    Raises:
        StopIteration: If ``loader`` yields no batches at all.
    """
    try:
        _, batch = next(loader_iter)
    except StopIteration:
        # Epoch finished: start a new pass over the loader.
        loader_iter = enumerate(loader)
        _, batch = next(loader_iter)
    return batch, loader_iter


def main():
    """Train the segmentation network, optionally with self-training.

    Reads all settings from the module-level ``args`` namespace. The steps are:

    1. Build ``Res_Deeplab`` and copy every pretrained parameter whose name
       and size match the current model.
    2. Build the data loaders. When ``args.partial_data`` is set, the first
       ``partial_data`` fraction of a shuffled (or pickled) id list is used as
       labeled data and the remainder as unlabeled data.
    3. For ``args.num_steps`` iterations, accumulate gradients over
       ``args.iter_size`` sub-batches of the supervised loss and, once
       ``i_iter >= args.semi_start`` with ``args.lambda_semi > 0``, of a loss
       against the network's own argmax predictions on unlabeled images.
    4. Save snapshots every ``args.save_pred_every`` iterations and a final
       checkpoint into ``args.snapshot_dir``.

    Side effects: creates ``args.snapshot_dir``, writes ``train_id.pkl`` there
    when partial data is used, writes ``.pth`` checkpoints and prints progress.

    Raises:
        ValueError: If the unlabeled branch is reached but no unlabeled loader
            exists because ``args.partial_data`` is ``None``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu
    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)

    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    # Only populated when a labeled/unlabeled split is requested below.
    trainloader_remain = None
    trainloader_remain_iter = None

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE: unpickling can execute arbitrary code; only pass id files
            # that come from a trusted source. Pickles must be read in binary
            # mode.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the split so the same ids can be reused in a later run.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # loss/bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    # NOTE(review): size is given as (w, h) while input_size is (h, w);
    # this is harmless for square crops -- confirm before using non-square
    # input sizes.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_unlabeled_seg_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            # train Segmentation
            # train with labeled images
            batch, trainloader_iter = _next_batch(trainloader,
                                                  trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(gpu)
            pred = interp(model(images))

            # computing loss
            loss_seg = loss_calc(pred, labels, gpu)

            # proper normalization
            loss = loss_seg / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size

            # train with unlabeled
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                if trainloader_remain is None:
                    raise ValueError(
                        'lambda_semi > 0 needs unlabeled data, but '
                        'args.partial_data is None so no unlabeled split '
                        'was created')

                batch, trainloader_remain_iter = _next_batch(
                    trainloader_remain, trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(gpu)

                pred = interp(model(images))
                # Use the network's own argmax prediction as the target.
                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt = torch.FloatTensor(semi_gt)

                loss_unlabeled_seg = args.lambda_semi * loss_calc(
                    pred, semi_gt, gpu)
                loss_unlabeled_seg = loss_unlabeled_seg / args.iter_size
                loss_unlabeled_seg.backward()

                # Report the loss without the lambda_semi weighting.
                loss_unlabeled_seg_value += loss_unlabeled_seg.data.cpu(
                ).numpy() / args.lambda_semi

            else:
                # 0 while waiting for semi_start, None when the
                # semi-supervised term is disabled altogether.
                loss_unlabeled_seg_value = 0 if args.lambda_semi > 0 else None

        optimizer.step()

        # None cannot be rendered with a float format spec, so format the
        # unlabeled loss separately.
        if loss_unlabeled_seg_value is None:
            unlabeled_text = 'n/a'
        else:
            unlabeled_text = '{0:.3f}'.format(loss_unlabeled_seg_value)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_unlabeled_seg = {3} '
            .format(i_iter, args.num_steps, loss_seg_value, unlabeled_text))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.lambda_semi) + '_' + str(args.random_seed) +
                    '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_' + str(i_iter) + '_' + str(args.lambda_semi) + '_' +
                    str(args.random_seed) + '.pth'))

    # `start` is not set in this function; presumably a module-level timer
    # taken at import time -- verify against the top of the file.
    end = timeit.default_timer()
    print(end - start, 'seconds')