Example #1
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.com_size.split(','))
    com_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu
    torch.cuda.set_device(args.gpu)

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        saved_state_dict = torch.load(
            args.restore_from,
            map_location=lambda storage, loc: storage.cuda(args.gpu))
        model.load_state_dict(saved_state_dict)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    print('current CUDA device:', torch.cuda.current_device())
    ############################
    #validation data
    testloader = data.DataLoader(dataset.cityscapes_dataset.cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target_val,
        crop_size=input_size_target,
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set_val),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

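    # Evaluation metadata: the Cityscapes label-id -> train-id mapping and the list
    # of ground-truth label files for the validation split.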
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    mapping = np.array(info['label2train'], dtype=int)
    label_path_list = './dataset/cityscapes_list/label.txt'
    gt_imgs = open(label_path_list, 'r').read().splitlines()
    gt_imgs = [join('./data/Cityscapes/data/gtFine/val', x) for x in gt_imgs]

    interp_val = nn.Upsample(size=(com_size[1], com_size[0]), mode='bilinear')

    ############################

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
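    # (Not shown here: a typical optim_parameters implementation returns parameter
    # groups, e.g. backbone weights at args.learning_rate and the classifier at
    # 10 * args.learning_rate, so SGD can apply a different rate per group. This is
    # an assumption about the model code, which lives in another file.)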

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    # interp_target = nn.UpsamplingBilinear2d(size=(input_size_target[1], input_size_target[0]))

    Softmax = torch.nn.Softmax(dim=1)

    for i_iter in range(args.num_steps):

        # loss_seg_value1 = 0
        loss_seg_value = 0
        loss_weak_value = 0
        loss_neg_value = 0
        loss_lse_source_value = 0
        loss_lse_target_value = 0
        entropy_samples_value = 0
        model.train()
        optimizer.zero_grad()

        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):

            # train with source

            #   _, batch = next(trainloader_iter)
            #   images, labels, class_label_source, mask_weakly, _, name = batch
            #   images = Variable(images).cuda(args.gpu)
            #   pred = model(images)
            #   pred = interp(pred)

            #   loss_seg = loss_calc(pred, labels, args.gpu)

            #   num = torch.sum(mask_weakly[0][0]).data.item()
            #   class_label_source_lse = class_label_source.type(torch.FloatTensor)
            #   exp_source = torch.min(torch.exp(1*pred), Variable(torch.exp(torch.tensor(40.0))).cuda(args.gpu))
            #   lse  = (1.0/1) * torch.log( (512*256/num) * AvePool(torch.exp(1*pred) * mask_weakly.type(torch.FloatTensor).cuda(args.gpu)))
            #   loss_lse_source = bce_loss(lse, Variable(class_label_source_lse.reshape(lse.

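            # Target branch: in this stripped-down variant only a forward pass is run;
            # no target loss is computed and the prediction is discarded below.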
            _, batch = next(targetloader_iter)
            images, class_label, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            _, pred_target = model(images)

    #   optimizer.step()
        del pred_target, batch, images

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f} loss_lse_source = {3:.3f} loss_lse_target = {4:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_lse_source_value, loss_lse_target_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            model.eval()
            hist = np.zeros((19, 19))
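            # hist accumulates a 19x19 confusion matrix over the validation set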

            f = open(args.results_dir, 'a')
            for index, batch in enumerate(testloader):
                print(index)
                image, _, name = batch
                _, output = model(
                    Variable(image, volatile=True).cuda(args.gpu))
                pred = interp_val(output)
                pred = pred[0].permute(1, 2, 0)
                pred = torch.max(pred, 2)[1].byte()
                pred_cpu = pred.data.cpu().numpy()
                del pred, output
                label = Image.open(gt_imgs[index])
                label = np.array(label.resize(com_size, Image.NEAREST))
                label = label_mapping(label, mapping)
                hist += fast_hist(label.flatten(), pred_cpu.flatten(), 19)

            mIoUs = per_class_iu(hist)
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            print(mIoU)
            f.write('i_iter:{:d},        miou:{:0.5f} \n'.format(i_iter, mIoU))
            f.close()
Example #2
    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
        cityscapesDataSet(args.data_dir_target, args.data_list_target, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size_target,
                    scale=False, mirror=args.random_mirror, mean=IMG_MEAN, set=args.set),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    # interp_target = nn.UpsamplingBilinear2d(size=(input_size_target[1], input_size_target[0]))
Example #3
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    w, h = map(int, args.com_size.split(','))
    com_size = (w, h)

    cudnn.enabled = True
    gpu = args.gpu
    torch.cuda.set_device(args.gpu)

    ############################
    #validation data
    testloader = data.DataLoader(cityscapesDataSet(args.data_dir_target,
                                                   args.data_list_target_val,
                                                   crop_size=input_size_target,
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set_val),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    mapping = np.array(info['label2train'], dtype=int)
    label_path_list = './dataset/cityscapes_list/label.txt'
    gt_imgs = open(label_path_list, 'r').read().splitlines()
    gt_imgs = [
        osp.join('./data/Cityscapes/data/gtFine/val', x) for x in gt_imgs
    ]

    interp_val = nn.UpsamplingBilinear2d(size=(com_size[1], com_size[0]))

    ############################

    # Create network
    #  if args.model == 'DeepLab':
    #      model = Res_Deeplab(num_classes=args.num_classes)
    #      if args.restore_from[:4] == 'http' :
    #          saved_state_dict = model_zoo.load_url(args.restore_from)
    #      else:
    #           saved_state_dict = torch.load(args.restore_from, map_location=lambda storage, loc: storage.cuda(args.gpu))
    #
    #      new_params = model.state_dict().copy()
    #      for i in saved_state_dict:
    #          # Scale.layer5.conv2d_list.3.weight
    #          i_parts = i.split('.')
    #          # print i_parts
    #          if not args.num_classes == 19 or not i_parts[1] == 'layer5':
    #              new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
    #              # print i_parts
    #      model.load_state_dict(new_params)

    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        saved_state_dict = torch.load(
            args.restore_from,
            map_location=lambda storage, loc: storage.cuda(args.gpu))
        model.load_state_dict(saved_state_dict)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
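    # one discriminator per segmentation output branch (the model returns two predictions)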

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1
    AvePool = torch.nn.AvgPool2d(kernel_size=(512, 1024))
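    # average pooling over the full (512, 1024) score map, i.e. a global average per class;
    # this feeds the log-sum-exp style image-level loss below (assumes the target
    # predictions are upsampled to 512x1024)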

    for i_iter in range(args.num_steps):
        model.train()
        loss_lse_target_two_value = 0
        loss_lse_target_value = 0
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, class_label_source, mask_weakly, _, name = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.item() / args.iter_size
            loss_seg_value2 += loss_seg2.data.item() / args.iter_size

            # train with target

            _, batch = next(targetloader_iter)
            images, class_label, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

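            # Weak (image-level) supervision on the target: exponentiate the scores
            # (clamped at exp(40) for numerical stability), global-average-pool them,
            # and take the log (a log-sum-exp style aggregation), then match the
            # image-level class labels with BCEWithLogitsLoss.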
            class_label_target_lse = class_label.type(torch.FloatTensor)
            exp_target = torch.min(
                torch.exp(1 * pred_target2),
                Variable(torch.exp(torch.tensor(40.0))).cuda(args.gpu))
            lse = (1.0 / 1) * torch.log(AvePool(exp_target))
            loss_lse_target = bce_loss(
                lse,
                Variable(class_label_target_lse.reshape(lse.size())).cuda(
                    args.gpu))

            #    exp_target = torch.min(torch.exp(1*pred_target1), Variable(torch.exp(torch.tensor(40.0))).cuda(args.gpu))
            #    lse  = (1.0/1) * torch.log(AvePool(exp_target))
            #    loss_lse_target_two = bce_loss(lse, Variable(class_label_target_lse.reshape(lse.size())).cuda(args.gpu))

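            # Adversarial term for the generator: target predictions are pushed to be
            # classified as "source" by the (currently frozen) discriminators.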
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2 + 0.2 * loss_lse_target  # + 0.02 * loss_lse_target_two
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.item(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.item(
            ) / args.iter_size
            loss_lse_target_value += loss_lse_target.data.item(
            ) / args.iter_size
            #   loss_lse_target_two_value += loss_lse_target_two.data.item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()
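            # detached so that discriminator gradients do not propagate into the segmentation network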

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))
            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.item()
            loss_D_value2 += loss_D2.data.item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.item()
            loss_D_value2 += loss_D2.data.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()
        del D_out1, D_out2, pred1, pred2, pred_target1, pred_target2, images, labels

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f} loss_lse_target = {8:.3f} loss_lse_target2 = {9:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, loss_lse_target_value,
                    loss_lse_target_two_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5' + '.pth'))
            torch.save(model_D2.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + '_D2.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, 'GTA5_' + '_D1.pth'))

            hist = np.zeros((19, 19))
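            # hist accumulates a 19x19 confusion matrix over the validation set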
            #       model.cuda(0)
            model.eval()
            f = open(args.results_dir, 'a')
            for index, batch in enumerate(testloader):
                print(index)
                image, _, _, name = batch
                output1, output2 = model(
                    Variable(image, volatile=True).cuda(args.gpu))
                pred = interp_val(output2)
                pred = pred[0].permute(1, 2, 0)
                pred = torch.max(pred, 2)[1].byte()
                pred_cpu = pred.data.cpu().numpy()
                del pred, output1, output2
                label = Image.open(gt_imgs[index])
                label = np.array(label.resize(com_size, Image.NEAREST))
                label = label_mapping(label, mapping)
                hist += fast_hist(label.flatten(), pred_cpu.flatten(), 19)
    #      model.cuda(args.gpu)
            mIoUs = per_class_iu(hist)
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            print(mIoU)
            f.write('i_iter:{:d},        miou:{:0.5f} \n'.format(i_iter, mIoU))
            f.close()
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.com_size.split(','))
    com_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    ############################
    #validation data
    testloader = data.DataLoader(dataset.cityscapes_dataset.cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target_val,
        crop_size=input_size,
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set_val),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)
    with open('./dataset/cityscapes_list/info.json', 'r') as fp:
        info = json.load(fp)
    mapping = np.array(info['label2train'], dtype=int)
    label_path_list = './dataset/cityscapes_list/label.txt'
    gt_imgs = open(label_path_list, 'r').read().splitlines()
    gt_imgs = [join('./data/Cityscapes/data/gtFine/val', x) for x in gt_imgs]

    interp_val = nn.Upsample(size=(com_size[1], com_size[0]), mode='bilinear')

    ############################

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = Res_Deeplab(num_classes=args.num_classes)
        #   if args.restore_from[:4] == 'http' :
        #       saved_state_dict = model_zoo.load_url(args.restore_from)
        #   else:
        saved_state_dict = torch.load(args.restore_from)

        #new_params = model.state_dict().copy()
        #   for i in saved_state_dict:
        #       # Scale.layer5.conv2d_list.3.weight
        #       i_parts = i.split('.')
        #       # print i_parts
        #       if not args.num_classes == 19 or not i_parts[1] == 'layer5':
        #           new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        # print i_parts
        model.load_state_dict(saved_state_dict)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    # interp_target = nn.UpsamplingBilinear2d(size=(input_size_target[1], input_size_target[0]))

    Softmax = torch.nn.Softmax(dim=1)
    AvePool = torch.nn.AvgPool2d(kernel_size=(256, 512))

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        entropy_samples_value = 0
        model.train()
        optimizer.zero_grad()

        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):

            # train with source
            # make layer6 trainable for the supervised source step
            for param in model.layer6.parameters():
                param.requires_grad = True

            _, batch = next(trainloader_iter)
            images, labels, class_label_source, _, name = batch

            images = Variable(images).cuda(args.gpu)
            pred = model(images)
            pred = interp(pred)
            loss_seg = loss_calc(pred, labels, args.gpu)

            loss = loss_seg
            loss.backward()
            loss_seg_value += loss_seg.data.item() / args.iter_size

            # train with target

            # freeze layer6 for the unsupervised target step (the other layers still receive gradients)
            for param in model.layer6.parameters():
                param.requires_grad = False

            _, batch = next(targetloader_iter)
            images, class_label, _, _ = batch

            images = Variable(images).cuda(args.gpu)
            pred_target = model(images)
            pred_target = interp(pred_target)

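            # Entropy minimization on the target: per-pixel entropy of the softmax output,
            # summed over classes, normalised by the number of pixels, and weighted by 0.1.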
            prob_tar = F.softmax(pred_target, dim=1)
            log_prob_tar = torch.log(prob_tar + 1e-45)
            entropy_samples = -torch.sum(torch.mul(prob_tar, log_prob_tar)) / (
                input_size_target[0] * input_size_target[1])
            loss = 0.1 * entropy_samples

            loss = loss / args.iter_size
            loss.backward()
            entropy_samples_value += entropy_samples.data.item(
            ) / args.iter_size

            optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}  entropy_samples = {3:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    entropy_samples_value))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            model.eval()
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            hist = np.zeros((19, 19))

            f = open(args.results_dir, 'a')
            for index, batch in enumerate(testloader):
                print(index)
                image, _, name = batch
                output = model(Variable(image, volatile=True).cuda(args.gpu))
                pred = interp_val(output)
                pred = pred[0].permute(1, 2, 0)
                pred = torch.max(pred, 2)[1].byte()
                pred = pred.data.cpu().numpy()
                label = Image.open(gt_imgs[index])
                label = np.array(label.resize(com_size, Image.NEAREST))
                label = label_mapping(label, mapping)
                hist += fast_hist(label.flatten(), pred.flatten(), 19)

            mIoUs = per_class_iu(hist)
            mIoU = round(np.nanmean(mIoUs) * 100, 2)
            print(mIoU)
            f.write('i_iter:{:d},        miou:{:0.5f} \n'.format(i_iter, mIoU))
            f.close()