Example #1
0
def main():
    """Precompute per-class feature centers on the target dataset.

    Loads a pretrained segmentation model, runs it (no grad) over the
    target set with labels, accumulates the 256-dim feature centroid of
    each of the 19 classes via ``class_center_precal``, and saves the
    averaged centers to ``./target_center.npy``.
    """
    args = get_arguments()
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args.cuda_device_id)

    # Use elif/else so an unknown model name fails fast with a clear
    # error instead of an UnboundLocalError when `model` is used below.
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported model: %s' % args.model)

    device = torch.device("cuda" if not args.cpu else "cpu")

    saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.to(device)

    targetloader_center = data.DataLoader(
        cityscapesDataSetLabel(args.data_dir, args.data_list,
                               crop_size=(1024, 512), mean=IMG_MEAN,
                               scale=False, mirror=False, set=args.set),
        batch_size=1, shuffle=False, pin_memory=True)

    count_class = np.zeros((19, 1))          # per-class pixel counts
    class_center_temp = np.zeros((19, 256))  # running sum of class features

    for index, batch in enumerate(targetloader_center):
        if index % 100 == 0:
            print('%d processd' % index)
        images, labels, _, _ = batch
        images = images.to(device)
        labels = labels.long().to(device)

        with torch.no_grad():
            feature, _ = model(images)

        class_center, count_class_t = class_center_precal(feature, labels)
        count_class += count_class_t.numpy()
        class_center_temp += class_center.cpu().data[0].numpy()

    count_class[count_class == 0] = 1  # avoid divide-by-zero for absent classes

    class_center = class_center_temp / count_class
    np.save('./target_center.npy', class_center)
Example #2
0
def main():
    """Create the model and start the training.

    Two-stage GTA5 -> Cityscapes domain-adaptation training:

    * STAGE 1: supervised source training (alternating the translated and
      stylized GTA5 loaders every other iteration) plus output-level
      adversarial alignment of target predictions via a discriminator.
    * STAGE 2 (else branch): supervised source training plus self-training
      on target pseudo labels; the discriminator is not updated.

    # NOTE(review): relies on module-level globals (args, STAGE,
    # CHECKPOINT, IMG_MEAN) and helpers (loss_calc, load_checkpoint,
    # eval, compute_mIoU, adjust_learning_rate*) defined elsewhere in
    # the file -- inferred from usage, confirm.
    """
    # Parse "W,H" strings into (width, height) tuples.
    w, h = map(int, args.input_size_source.split(','))
    input_size_source = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy pretrained weights, dropping the first name component and
        # skipping the classifier head ('layer5') when num_classes == 19.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

        # model.load_state_dict(saved_state_dict)

    elif args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes,
                           pretrained=True,
                           vgg16_caffe_path=args.restore_from)

        # saved_state_dict = torch.load(args.restore_from)
        # model.load_state_dict(saved_state_dict)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    model.train()
    model.cuda(args.gpu)
    cudnn.benchmark = True

    #Discrimintator setting
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda(args.gpu)

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # labels for adversarial training
    source_adv_label = 0
    target_adv_label = 1

    #Dataloader
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Source loader over image-translated GTA5 data.
    trainloader = data.DataLoader(GTA5DataSet(args.translated_data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size_source,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Second source loader over style-transferred GTA5 data (stage 1 only).
    style_trainloader = data.DataLoader(GTA5DataSet(
        args.stylized_data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_source,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        pin_memory=True)

    style_trainloader_iter = enumerate(style_trainloader)

    # STAGE selects the target loader: unlabeled images for adversarial
    # training (stage 1) vs pseudo-labeled images for self-training.
    if STAGE == 1:
        targetloader = data.DataLoader(cityscapesDataSet(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        targetloader_iter = enumerate(targetloader)

    else:
        #Dataloader for self-training
        # NOTE(review): label_folder is a literal placeholder string and
        # must be replaced with the real pseudo-label directory.
        targetloader = data.DataLoader(cityscapesDataSetLabel(
            args.data_dir_target,
            args.data_list_target,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_target,
            mean=IMG_MEAN,
            set=args.set,
            label_folder='Path to generated pseudo labels'),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True)

        targetloader_iter = enumerate(targetloader)

    # Upsamplers back to full (H, W) resolution for loss computation.
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # load checkpoint
    model, model_D, optimizer, start_iter = load_checkpoint(
        model,
        model_D,
        optimizer,
        filename=args.snapshot_dir + 'checkpoint_' + CHECKPOINT + '.pth.tar')

    for i_iter in range(start_iter, args.num_steps):
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        #train segementation network
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source
        # Stage 1 alternates translated / stylized source batches.
        if STAGE == 1:
            if i_iter % 2 == 0:
                _, batch = next(trainloader_iter)
            else:
                _, batch = next(style_trainloader_iter)

        else:
            _, batch = next(trainloader_iter)

        image_source, label, _, _ = batch
        image_source = Variable(image_source).cuda(args.gpu)

        pred_source = model(image_source)
        pred_source = interp(pred_source)

        loss_seg_source = loss_calc(pred_source, label, args.gpu)
        loss_seg_source_value = loss_seg_source.item()
        loss_seg_source.backward()

        if STAGE == 2:
            # train with target
            _, batch = next(targetloader_iter)
            image_target, target_label, _, _ = batch
            image_target = Variable(image_target).cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            #target segmentation loss
            # Self-training: supervise target predictions with pseudo labels.
            loss_seg_target = loss_calc(pred_target,
                                        target_label,
                                        gpu=args.gpu)
            loss_seg_target.backward()

        # optimize
        optimizer.step()

        if STAGE == 1:
            # train with target
            _, batch = next(targetloader_iter)
            # NOTE(review): 3-item unpack here vs 4-item for the labeled
            # loaders above -- presumably cityscapesDataSet yields
            # (image, size, name); confirm against the dataset class.
            image_target, _, _ = batch
            image_target = Variable(image_target).cuda(args.gpu)

            pred_target = model(image_target)
            pred_target = interp_target(pred_target)

            #output-level adversarial training
            # Fool D: make target outputs be classified as source.
            # NOTE(review): F.softmax without dim= is deprecated; the
            # implicit dim for 4D input is the channel dim -- confirm.
            D_output_target = model_D(F.softmax(pred_target))
            loss_adv = bce_loss(
                D_output_target,
                Variable(
                    torch.FloatTensor(D_output_target.data.size()).fill_(
                        source_adv_label)).cuda(args.gpu))
            loss_adv = loss_adv * args.lambda_adv
            loss_adv.backward()

            #train discriminator
            for param in model_D.parameters():
                param.requires_grad = True

            # Detach so D gradients do not flow back into the segmenter.
            pred_source = pred_source.detach()
            pred_target = pred_target.detach()

            D_output_source = model_D(F.softmax(pred_source))
            D_output_target = model_D(F.softmax(pred_target))

            loss_D_source = bce_loss(
                D_output_source,
                Variable(
                    torch.FloatTensor(D_output_source.data.size()).fill_(
                        source_adv_label)).cuda(args.gpu))
            loss_D_target = bce_loss(
                D_output_target,
                Variable(
                    torch.FloatTensor(D_output_target.data.size()).fill_(
                        target_adv_label)).cuda(args.gpu))

            # Halve each term so the combined D loss stays balanced.
            loss_D_source = loss_D_source / 2
            loss_D_target = loss_D_target / 2

            loss_D_source.backward()
            loss_D_target.backward()

            #optimize
            optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg_source = {2:.5f}'.format(
            i_iter, args.num_steps, loss_seg_source_value))

        # Periodic snapshot + on-the-fly Cityscapes evaluation.
        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            state = {
                'iter': i_iter,
                'model': model.state_dict(),
                'model_D': model_D.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(
                state,
                osp.join(args.snapshot_dir,
                         'checkpoint_' + str(i_iter) + '.pth.tar'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_D_' + str(i_iter) + '.pth'))

            cityscapes_eval_dir = osp.join(args.cityscapes_eval_dir,
                                           str(i_iter))
            if not os.path.exists(cityscapes_eval_dir):
                os.makedirs(cityscapes_eval_dir)

            # NOTE(review): project-local eval() shadows the builtin here.
            eval(osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'),
                 cityscapes_eval_dir, i_iter)

            iou19, iou13, iou = compute_mIoU(cityscapes_eval_dir, i_iter)
            outputfile = open(args.output_file, 'a')
            outputfile.write(
                str(i_iter) + '\t' + str(iou19) + '\t' +
                str(iou.replace('\n', ' ')) + '\n')
            outputfile.close()
Example #3
0
def main():
    """Create the model and start the training.

    Mixed-precision (NVIDIA apex amp) adversarial adaptation with a global
    output-space discriminator (model_D2) plus one per-class discriminator
    (model_clsD[i]) per segmentation class. Each iteration trains the
    segmenter on source GT, target pseudo labels and adversarial losses,
    then trains all discriminators on detached predictions.

    # NOTE(review): relies on module-level globals (args, IMG_MEAN,
    # LABEL_DIRECTORY_TARGET, LAMBDA_CLS_ADV, NUM_STEPS) and helpers
    # (amp_backward, adjust_learning_rate*) -- inferred from usage, confirm.
    """

    device = torch.device("cuda" if not args.cpu else "cpu")

    # Parse "W,H" strings into (width, height) tuples.
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMultiFeature(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http' :
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy pretrained weights, dropping the first name component and
        # skipping the classifier head ('layer5') when num_classes == 19.
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    model_D2 = FCDiscriminator(num_classes=args.num_classes).to(device)

    model_D2.train()
    model_D2.to(device)


    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Target loader with pseudo labels read from LABEL_DIRECTORY_TARGET.
    cityset = cityscapesDataSetLabel(args.data_dir_target, args.data_list_target,
                                    max_iters=args.num_steps * args.iter_size * args.batch_size,
                                    crop_size=input_size_target,
                                    mean=IMG_MEAN,
                                    set=args.set, label_folder=LABEL_DIRECTORY_TARGET)
    targetloader = data.DataLoader(cityset,
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)


    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # init cls D
    # One discriminator + optimizer per class, each wrapped by apex amp.
    model_clsD = []
    optimizer_clsD = []
    for i in range(args.num_classes):
        model_temp = FCDiscriminatorCLS(num_classes=args.num_classes).to(device).train()
        optimizer_temp = optim.Adam(model_temp.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
        optimizer_temp.zero_grad()
        #model_temp, optimizer_temp = amp.initialize(
        #    model_temp, optimizer_temp, opt_level="O1", 
        #    keep_batchnorm_fp32=None, loss_scale="dynamic"
        #)
        model_temp, optimizer_temp = amp.initialize(
            model_temp, optimizer_temp, opt_level="O1", 
            keep_batchnorm_fp32=None, loss_scale="dynamic"
        )
        model_clsD.append(model_temp)
        optimizer_clsD.append(optimizer_temp)

    model, optimizer = amp.initialize(
        model, optimizer, opt_level="O1", 
        keep_batchnorm_fp32=None, loss_scale="dynamic"
    )

    model_D2, optimizer_D2 = amp.initialize(
        model_D2, optimizer_D2, opt_level="O1", 
        keep_batchnorm_fp32=None, loss_scale="dynamic"
    )

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Upsamplers back to full (H, W) resolution for loss computation.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    
    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0
        loss_cls_adv = 0
        loss_cls_adv_value = 0
        loss_cls_D = 0
        loss_cls_D_value = 0
        loss_self_seg_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for i in range(args.num_classes):
            optimizer_clsD[i].zero_grad()
            adjust_learning_rate_D(optimizer_clsD[i], i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D2.parameters():
                param.requires_grad = False

            for i in range(args.num_classes):
                for param in model_clsD[i].parameters():
                    param.requires_grad = False

            # train with source
            
            _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = images.to(device)
            labels = labels.long().to(device)

            _, pred2 = model(images)
            pred2 = interp(pred2)

            loss_seg2 = seg_loss(pred2, labels)
            loss = loss_seg2

            # proper normalization
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_seg_value2 += loss_seg2.item() / args.iter_size
            
            # train with target

            _, batch = targetloader_iter.__next__()
            images, target_labels, _, name = batch
            images = images.to(device)
            target_labels = target_labels.long().to(device)

            _, pred_target2 = model(images)
            pred_target2 = interp_target(pred_target2)

            # Fool the global D: target outputs should look like source.
            pred_target_score = F.softmax(pred_target2, dim=1)
            D_out2 = model_D2(pred_target_score)
            loss_adv_target2 = bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))

            # Self-training loss against pseudo labels.
            loss_self_seg = seg_loss(pred_target2, target_labels)
            loss_self_seg = loss_self_seg / args.iter_size
            loss_self_seg_value = loss_self_seg.item()
            
            # Per-class adversarial loss: only on pixels where prediction
            # and pseudo label agree on class i; other pixels ignored (255).
            _, target_pred_cls = torch.max(pred_target_score, dim=1)
            target_pred_cls = target_pred_cls.long().detach()
            for i in range(args.num_classes):
                cls_mask = (target_pred_cls==i) * (target_labels==i)
                if torch.sum(cls_mask) == 0:
                    continue
                # NOTE(review): torch.tensor(x.data) is the deprecated
                # copy-construct idiom; x.detach().clone() is preferred.
                cls_gt = torch.tensor(target_labels.data).long().to(device)
                cls_gt[~cls_mask] = 255
                cls_gt[cls_mask] = source_label
                cls_out = model_clsD[i](pred_target_score)
                loss_cls_adv += seg_loss(cls_out, cls_gt)
            # NOTE(review): if no class matched above, loss_cls_adv is
            # still the int 0 and .item() raises AttributeError -- confirm.
            loss_cls_adv_value = loss_cls_adv.item() / args.iter_size
                    

            loss = args.lambda_adv_target2 * loss_adv_target2 + LAMBDA_CLS_ADV * loss_cls_adv + loss_self_seg
            loss = loss / args.iter_size
            amp_backward(loss, optimizer)
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size


            # train D

            # bring back requires_grad

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            # Detach so D gradients do not flow back into the segmenter.
            pred2 = pred2.detach()
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D2 = bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(source_label).to(device))
            loss_D2 = loss_D2 / args.iter_size / 2
            amp_backward(loss_D2, optimizer_D2)
            loss_D_value2 += loss_D2.item()

            # train with target
            pred_target2 = pred_target2.detach()
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D2 = bce_loss(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(target_label).to(device))
            loss_D2 = loss_D2 / args.iter_size / 2
            amp_backward(loss_D2, optimizer_D2)
            loss_D_value2 += loss_D2.item()

            for i in range(args.num_classes):
                for param in model_clsD[i].parameters():
                    param.requires_grad = True

            # Per-class D training on detached source predictions.
            pred_source_score = F.softmax(pred2, dim=1)
            _, source_pred_cls = torch.max(pred_source_score, dim=1)
            source_pred_cls = source_pred_cls.long().detach()
            for i in range(args.num_classes):
                cls_mask = (source_pred_cls==i) * (labels==i)
                if torch.sum(cls_mask) == 0:
                    continue
                cls_gt = torch.tensor(source_pred_cls.data).long().to(device)
                cls_gt[~cls_mask] = 255
                cls_gt[cls_mask] = source_label
                cls_out = model_clsD[i](pred_source_score)
                loss_cls_D = seg_loss(cls_out, cls_gt) / 2
                amp_backward(loss_cls_D, optimizer_clsD[i])
                loss_cls_D_value += loss_cls_D.item()

            # Per-class D training on detached target predictions.
            pred_target_score = F.softmax(pred_target2, dim=1)
            _, target_pred_cls = torch.max(pred_target_score, dim=1)
            target_pred_cls = target_pred_cls.long().detach()
            for i in range(args.num_classes):
                cls_mask = (target_pred_cls==i) * (target_labels==i)
                if torch.sum(cls_mask) == 0:
                    continue
                cls_gt = torch.tensor(target_pred_cls.data).long().to(device)
                cls_gt[~cls_mask] = 255
                cls_gt[cls_mask] = target_label
                cls_out = model_clsD[i](pred_target_score)
                # NOTE(review): this accumulation into loss_cls_adv happens
                # after its backward above and is never used again --
                # looks like dead/leftover code, confirm intent.
                loss_cls_adv += seg_loss(cls_out, cls_gt)
                loss_cls_D = seg_loss(cls_out, cls_gt) / 2
                amp_backward(loss_cls_D, optimizer_clsD[i])
                loss_cls_D_value += loss_cls_D.item()



        optimizer.step()
        optimizer_D2.step()
        for i in range(args.num_classes):
            optimizer_clsD[i].step()

        if args.tensorboard:
            scalar_info = {
                'loss_seg2': loss_seg_value2,
                'loss_adv_target2': loss_adv_target_value2,
                'loss_D2': loss_D_value2,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg2 = {2:.3f}, loss_adv2 = {3:.3f} loss_D2 = {4:.3f} loss_cls_adv = {5:.3f} loss_cls_D = {6:.3f} loss_self_seg = {7:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value2, loss_adv_target_value2, loss_D_value2, loss_cls_adv_value, loss_cls_D_value, loss_self_seg_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            for i in range(args.num_classes):
                # NOTE(review): filename does not include i, so every class
                # discriminator overwrites the same '..._clsD.pth' file
                # (unlike the per-i names in the periodic snapshot below).
                torch.save(model_clsD[i].state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(NUM_STEPS) + '_clsD.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
            for i in range(args.num_classes):
                torch.save(model_clsD[i].state_dict(), osp.join(args.snapshot_dir, 'GTA5_clsD'+str(i)+'.pth'))

    if args.tensorboard:
        writer.close()
def main():
    """Create the model and start the training."""

    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    bestIoU = 0
    bestIter = 0

    # Create network
    if args.model == 'ResNet':
        model = DeeplabMulti(num_classes=args.num_classes)
        saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict)

    if args.model == 'VGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        saved_state_dict = torch.load(args.restore_from)
        model.load_state_dict(saved_state_dict)

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # init D
    if args.model == 'ResNet':
        model_D = FCDiscriminator(num_classes=256).to(device)
        saved_state_dict = torch.load('./snapshots/BestGTA5_D.pth')
        model_D.load_state_dict(saved_state_dict)
    if args.model == 'VGG':
        model_D = FCDiscriminator(num_classes=256).to(device)

    model_D.train()
    model_D.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSetLabel(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    seg_loss = torch.nn.CrossEntropyLoss(ignore_index=255)

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    test_interp = nn.Upsample(size=(1024, 2048),
                              mode='bilinear',
                              align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # load calculated  class center for initilization
    class_center_source_ori = np.load('./source_center.npy')
    class_center_source_ori = torch.from_numpy(class_center_source_ori)

    class_center_target_ori = np.load('./target_center.npy')
    class_center_target_ori = torch.from_numpy(class_center_target_ori)

    # set up tensor board
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)

        writer = SummaryWriter(args.log_dir)

    for i_iter in range(args.num_steps):

        loss_seg = 0
        loss_adv_target_value = 0
        loss_D_value = 0
        loss_cla_value = 0
        loss_square_value = 0
        loss_st_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G

        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # train with source

        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)
        labels_s = labels  # copy for center calculation
        labels = labels.long().to(device)

        feature, prediction = model(images)
        feature_s = feature  # copy for center calculation
        prediction = interp(prediction)
        loss = seg_loss(prediction, labels)
        loss.backward(retain_graph=True)
        loss_seg = loss.item()

        # train with target

        _, batch = targetloader_iter.__next__()
        images, labels_pseudo, _, _ = batch
        labels_t = labels_pseudo  # copy for center calculation
        images = images.to(device)
        # ---- Generator (G) step on the target batch ----------------------
        # `labels_pseudo` are self-training pseudo labels loaded with the
        # target images; cast to long for CrossEntropy-style seg_loss.
        labels_pseudo = labels_pseudo.long().to(device)

        # Segmentation model returns (feature map, segmentation logits).
        feature_target, pred_target = model(images)
        feature_t = feature_target  # copy for center calculation
        # Discriminator returns (aux classification map, domain logits);
        # only the domain logits are used for the adversarial loss here.
        _, D_out = model_D(feature_target)
        # Adversarial loss: push D to label TARGET features as SOURCE
        # (source_label), i.e. make target features domain-confusable.
        loss_adv_target = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        #print(args.lambda_adv_target)
        loss = args.lambda_adv_target * loss_adv_target
        # retain_graph: the same forward graph is reused below for the
        # self-training loss (and the center-alignment loss).
        loss.backward(retain_graph=True)
        loss_adv_target_value = loss_adv_target.item()

        # Self-training: supervise upsampled target predictions with the
        # pseudo labels.
        pred_target = interp_target(pred_target)
        loss_st = seg_loss(pred_target, labels_pseudo)
        loss_st.backward(retain_graph=True)
        loss_st_value = loss_st.item()

        # class center alignment begin
        # Only enabled after a warm-up phase (10000 iters) so that the
        # per-class feature centers are computed from a reasonably
        # trained model. NOTE(review): 10000 is a magic number — consider
        # promoting it to a CLI argument.
        if i_iter > 10000:
            # feature_s / labels_s are the source batch's feature map and
            # labels bound earlier in the loop (above this window).
            class_center_source = class_center_cal(feature_s, labels_s)
            class_center_target = class_center_cal(feature_t, labels_t)
            # Exponential-moving-average update of the running centers.
            class_center_source_ori = class_center_update(
                class_center_source, class_center_source_ori,
                args.lambda_center_update)
            class_center_target_ori = class_center_update(
                class_center_target, class_center_target_ori,
                args.lambda_center_update)

            # Detach the source centers so gradients flow only through the
            # target side: target centers are pulled toward (fixed) source
            # centers, not the other way around.
            class_center_source_ori = class_center_source_ori.detach(
            )  #align target center to source

            # Squared L2 distance between per-class centers.
            center_diff = class_center_source_ori - class_center_target_ori
            loss_square = torch.pow(center_diff, 2).sum()

            loss = args.lambda_center * loss_square
            loss.backward()
            loss_square_value = loss_square.item()
        # class center alignment end

        # train D

        # bring back requires_grad
        # (D's params were frozen during the G step above this window.)
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source
        # `feature` is the source-batch feature map computed earlier in the
        # loop (above this window); detach so D's loss does not backprop
        # into the segmentation model.
        feature = feature.detach()
        cla, D_out = model_D(feature)
        cla = interp(cla)
        # Auxiliary classification head of D is supervised with the real
        # source labels.
        loss_cla = seg_loss(cla, labels)

        # D should classify source features as source_label.
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(source_label).to(device))
        # Halved so that the source and target halves average to one
        # full discriminator loss.
        loss_D = loss_D / 2
        #print(args.lambda_s)
        loss_Disc = args.lambda_s * loss_cla + loss_D
        loss_Disc.backward()

        loss_cla_value = loss_cla.item()
        loss_D_value = loss_D.item()

        # train with target
        # Detached for the same reason as the source features above.
        feature_target = feature_target.detach()
        _, D_out = model_D(feature_target)
        # D should classify target features as target_label.
        loss_D = bce_loss(
            D_out,
            torch.FloatTensor(
                D_out.data.size()).fill_(target_label).to(device))
        loss_D = loss_D / 2
        loss_D.backward()
        loss_D_value += loss_D.item()

        # Apply all accumulated gradients for both G and D.
        optimizer.step()
        optimizer_D.step()

        # Detach the running target centers so the EMA state does not keep
        # extending the autograd graph across iterations.
        class_center_target_ori = class_center_target_ori.detach()

        if args.tensorboard:
            # NOTE(review): loss_seg is logged as a tensor while the other
            # entries are Python floats (.item()) — consider loss_seg.item()
            # for consistency. loss_seg itself is computed above this window.
            scalar_info = {
                'loss_seg': loss_seg,
                'loss_cla': loss_cla_value,
                'loss_adv_target': loss_adv_target_value,
                'loss_st_value': loss_st_value,
                'loss_D': loss_D_value,
            }

            if i_iter % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, i_iter)

        #print('exp = {}'.format(args.snapshot_dir))
        # NOTE(review): loss_square_value is only assigned when
        # i_iter > 10000; this print presumably relies on an initial value
        # set before the loop — verify to rule out a NameError.
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_adv = {3:.3f} loss_D = {4:.3f} loss_cla = {5:.3f} loss_st = {6:.5f} loss_square = {7:.5f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv_target_value,
                    loss_D_value, loss_cla_value, loss_st_value,
                    loss_square_value))

        # Final checkpoint at the configured early-stop iteration.
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D.pth'))
            break

        # Periodic evaluation + snapshot: run inference on the Cityscapes
        # val split, compute mIoU, and keep the best-scoring checkpoint.
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            if not os.path.exists(args.save):
                os.makedirs(args.save)
            testloader = data.DataLoader(cityscapesDataSet(
                args.data_dir_target,
                args.data_list_target_test,
                crop_size=(1024, 512),
                mean=IMG_MEAN,
                scale=False,
                mirror=False,
                set='val'),
                                         batch_size=1,
                                         shuffle=False,
                                         pin_memory=True)
            model.eval()
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                with torch.no_grad():
                    output1, output2 = model(Variable(image).to(device))
                # Upsample logits to evaluation resolution, then take the
                # per-pixel argmax as the predicted class map.
                output = test_interp(output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
                output = Image.fromarray(output)
                # Keep only the file name component for saving.
                name = name[0].split('/')[-1]
                output.save('%s/%s' % (args.save, name))
            mIoUs = compute_mIoU(osp.join(args.data_dir_target, 'gtFine/val'),
                                 args.save, 'dataset/cityscapes_list')
            mIoU = round(np.nanmean(mIoUs) * 100, 2)

            print('===>  current   mIoU: ' + str(mIoU))
            print('===> last best  mIoU: ' + str(bestIoU))
            print('===> last best  iter: ' + str(bestIter))

            # Track and save the best model so far, plus a per-interval
            # snapshot regardless of score.
            if mIoU > bestIoU:
                bestIoU = mIoU
                bestIter = i_iter
                torch.save(model.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5.pth'))
                torch.save(model_D.state_dict(),
                           osp.join(args.snapshot_dir, 'BestGTA5_D.pth'))
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
            # Restore training mode after evaluation.
            model.train()

    if args.tensorboard:
        writer.close()