Ejemplo n.º 1
0
def make_data_loader(config):
    """Build the train/val DataLoaders for the city named by ``config.dataset``.

    Both splits share ``config.img_root`` and ``config.source_dist``. The
    training loader shuffles; both loaders drop the final incomplete batch.

    Args:
        config: object exposing ``dataset``, ``img_root``, ``source_dist``,
            ``batch_size``, ``train_num_workers`` and ``val_num_workers``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, num_class)`` where
        ``test_loader`` is always ``None`` and ``num_class`` is the training
        set's ``NUM_CLASSES``.
    """
    def build_split(split_name):
        # Same dataset options for every split; only the split name differs.
        return spacenet.Spacenet(city=config.dataset,
                                 split=split_name,
                                 img_root=config.img_root,
                                 source_dist=config.source_dist)

    training_data = build_split('train')
    validation_data = build_split('val')
    classes = training_data.NUM_CLASSES

    batch = config.batch_size
    training_loader = DataLoader(training_data,
                                 batch_size=batch,
                                 shuffle=True,
                                 num_workers=config.train_num_workers,
                                 drop_last=True)
    validation_loader = DataLoader(validation_data,
                                   batch_size=batch,
                                   shuffle=False,
                                   num_workers=config.val_num_workers,
                                   drop_last=True)
    return training_loader, validation_loader, None, classes
Ejemplo n.º 2
0
    def __init__(self, model_path, source, target, cuda=False):
        """Prepare a checkpoint and 'test' loaders for one source and one target city.

        NOTE: the dataset root and every model hyper-parameter are read from a
        module-level ``config`` object (it is not a parameter), which must
        exist in this module's globals.

        Args:
            model_path: path handed to ``torch.load``; the loaded object is
                passed straight to ``load_state_dict``.
            source: city name for the source-domain test set.
            target: city name for the target-domain test set.
            cuda: when true, load the checkpoint on its saved device and move
                the model to the GPU; otherwise map the checkpoint to the CPU.
        """
        self.source_set, self.target_set = (
            spacenet.Spacenet(city=city_name, split='test', img_root=config.img_root)
            for city_name in (source, target))
        self.source_loader, self.target_loader = (
            DataLoader(dataset, batch_size=16, shuffle=False, num_workers=2)
            for dataset in (self.source_set, self.target_set))

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        # map_location=None is torch.load's default (keep the saved device).
        location = None if cuda else torch.device('cpu')
        self.checkpoint = torch.load(model_path, map_location=location)
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
Ejemplo n.º 3
0
def make_data_loader(config):
    """Build the train/val DataLoaders for ``config.dataset``.

    Both splits are constructed with identical dataset options taken from
    ``config`` (ground-truth root, mean/std, augmentation flag, repeat count
    and sample transform). NOTE(review): ``if_augment`` is therefore also
    applied to the validation split — confirm that is intended.

    Only the training loader shuffles and drops the last incomplete batch.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, num_class)``;
        ``test_loader`` is always ``None``.
    """
    shared_options = {
        'city': config.dataset,
        'img_root': config.img_root,
        'gt_root': config.gt_root,
        'mean_std': config.mean_std,
        'if_augment': config.if_augment,
        'repeat_count': config.repeat_count,
        'transform_sample': config.transform_sample,
    }
    train_ds = spacenet.Spacenet(split='train', **shared_options)
    val_ds = spacenet.Spacenet(split='val', **shared_options)
    n_classes = train_ds.NUM_CLASSES

    train_dl = DataLoader(train_ds,
                          batch_size=config.batch_size,
                          shuffle=True,
                          num_workers=config.train_num_workers,
                          drop_last=True)
    val_dl = DataLoader(val_ds,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.val_num_workers)
    return train_dl, val_dl, None, n_classes
Ejemplo n.º 4
0
    def __init__(self, model_path, config, cuda=False):
        """Load a DeepLab checkpoint plus 'test' loaders for every city.

        ``config.dataset`` is the source city; every other entry of
        ``config.all_dataset`` becomes a target domain.

        Args:
            model_path: path passed to ``torch.load``; the loaded object is fed
                directly to ``load_state_dict``.
            config: experiment config (``dataset``, ``all_dataset``,
                ``img_root``, ``backbone``, ``out_stride``, ``sync_bn``,
                ``freeze_bn``).
            cuda: when true, load the checkpoint on its saved device and move
                the model to the GPU; otherwise map the checkpoint to the CPU.

        Raises:
            ValueError: if ``config.dataset`` is not in ``config.all_dataset``.
        """
        # Work on a copy: removing from config.all_dataset in place mutated the
        # caller's config, and a second construction then raised ValueError.
        self.target = list(config.all_dataset)
        self.target.remove(config.dataset)

        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset, split='test', img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set, batch_size=16, shuffle=False, num_workers=2)

        # load other domains
        self.target_set = []
        self.target_loader = []
        for city in self.target:
            city_set = spacenet.Spacenet(city=city, split='test', img_root=config.img_root)
            self.target_set.append(city_set)
            self.target_loader.append(DataLoader(city_set, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                backbone=config.backbone,
                output_stride=config.out_stride,
                sync_bn=config.sync_bn,
                freeze_bn=config.freeze_bn)
        # map_location=None keeps tensors on the device they were saved from.
        map_location = None if cuda else torch.device('cpu')
        self.checkpoint = torch.load(model_path, map_location=map_location)
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
Ejemplo n.º 5
0
def _accumulate_channel_stats(loader, mean_sum=0., std_sum=0., nb_samples=0.):
    """Add per-image channel means/stds from ``loader`` to running sums.

    Each batch's ``sample['image']`` is flattened to (batch, channels, -1);
    the per-image channel means and stds are summed over the batch and
    added to the accumulators. Returns ``(mean_sum, std_sum, nb_samples)``.
    """
    for sample in tqdm(loader):
        data = sample['image']
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)
        mean_sum += data.mean(2).sum(0)
        std_sum += data.std(2).sum(0)
        nb_samples += batch_samples
    return mean_sum, std_sum, nb_samples


def compute_mean_variance(city='Vegas',
                          real_root='/usr/xtmp/satellite/spacenet/',
                          fake_root='/usr/xtmp/satellite/FakeShanghai-Vegas/'):
    """Compute per-channel image statistics over a real and a fake train split.

    The 'train' split of ``city`` is read from both ``real_root`` and
    ``fake_root`` and pooled. Batch size and worker count come from the
    module-level ``config``; the last incomplete batch of each loader is
    dropped.

    NOTE: the reported "std" is the average of per-image channel standard
    deviations, not the pooled standard deviation over all pixels.

    Args:
        city: city name passed to ``spacenet.Spacenet``.
        real_root: image root of the real dataset.
        fake_root: image root of the generated (fake) dataset.

    Returns:
        tuple: ``((mean_0, mean_1, mean_2), (std_0, std_1, std_2))``.

    Raises:
        ValueError: if neither loader yields a complete batch.
    """
    def make_loader(img_root):
        dataset = spacenet.Spacenet(city=city, split='train', img_root=img_root)
        return DataLoader(dataset,
                          batch_size=config.batch_size,
                          shuffle=False,
                          num_workers=config.train_num_workers,
                          drop_last=True)

    loader = make_loader(real_root)
    fake_loader = make_loader(fake_root)

    # Accumulate real then fake into the same sums (same order as before).
    mean, std, nb_samples = _accumulate_channel_stats(loader)
    mean, std, nb_samples = _accumulate_channel_stats(fake_loader, mean, std, nb_samples)
    if not nb_samples:
        raise ValueError('no complete batches found for city {!r}'.format(city))

    mean /= nb_samples
    std /= nb_samples
    mean = mean.numpy().astype(float)
    std = std.numpy().astype(float)
    print('city: ', city)
    print('mean: ', mean)
    print('std:', std)
    return (mean[0], mean[1], mean[2]), (std[0], std[1], std[2])
    def __init__(self, model_path, config, bn, save_path, save_batch, sample_number, trial=100, cuda=False):
        """Load a DeepLab checkpoint and val/train loaders for every target city.

        Target cities are all entries of ``config.all_dataset`` except
        ``config.dataset``. For each one, the 'val' split is stored in
        ``target_set``/``target_loader`` and a ``sample_number``-limited
        'train' split in ``target_trainset``/``target_trainloader``. No source
        loader is built (``source_loader`` is ``None``).

        Args:
            model_path: checkpoint path; its ``'model'`` entry holds the
                state dict.
            config: experiment config (dataset and model options).
            bn, save_path, save_batch, sample_number, trial: stored as
                attributes (``sample_number`` is also forwarded to the train
                datasets).
            cuda: when true, load the checkpoint on its saved device and move
                the model to the GPU; otherwise map it to the CPU.

        Raises:
            ValueError: if ``config.dataset`` is not in ``config.all_dataset``.
        """
        self.bn = bn
        # Work on a copy: an in-place remove() mutated the caller's config and
        # made a second construction with the same config raise ValueError.
        self.target = list(config.all_dataset)
        self.target.remove(config.dataset)
        self.sample_number = sample_number
        # The source-domain loader is intentionally disabled for this evaluator.
        self.source_loader = None
        self.save_path = save_path
        self.save_batch = save_batch
        self.trial = trial
        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        dataset_options = dict(img_root=config.img_root,
                               gt_root=config.gt_root,
                               mean_std=config.mean_std,
                               if_augment=config.if_augment,
                               repeat_count=config.repeat_count)
        # load other domains
        for city in self.target:
            # NOTE(review): the "test" set uses split='val' — confirm intended.
            test = spacenet.Spacenet(city=city, split='val', **dataset_options)
            self.target_set.append(test)
            self.target_loader.append(DataLoader(test, batch_size=16, shuffle=False, num_workers=2))

            train = spacenet.Spacenet(city=city, split='train', sample_number=sample_number, **dataset_options)
            self.target_trainset.append(train)
            self.target_trainloader.append(DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                backbone=config.backbone,
                output_stride=config.out_stride,
                sync_bn=config.sync_bn,
                freeze_bn=config.freeze_bn)
        # map_location=None keeps tensors on the device they were saved from.
        map_location = None if cuda else torch.device('cpu')
        self.checkpoint = torch.load(model_path, map_location=map_location)
        self.model.load_state_dict(self.checkpoint['model'])

        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
Ejemplo n.º 7
0
def getImagesActivations(target=('Vegas', 'Shanghai'),
                         save_path='/home/home1/swarnakr/main/DomainAdaptation/DA_viaBatchNorm/configs/ganDA/train_log/feats/',
                         img_root='/usr/xtmp/satellite/spacenet/'):
    """Collect training images, labels and precomputed activations per city.

    For each city the 'train' split is iterated in order (batch size 16, no
    shuffling) and batch ``i`` is paired with the activation file
    ``<first letter of city>Feats_<i>.pt`` in ``save_path``.

    Args:
        target: city names to load, in order.
        save_path: directory holding the saved activation batches.
        img_root: dataset image root passed to ``spacenet.Spacenet``.

    Returns:
        tuple: ``(images, activations, gts, cities)`` — parallel lists with
        one numpy array (or city name) per sample.

    Raises:
        ValueError: if an activation batch and its image batch differ in size.
    """
    images = []
    activations = []
    gts = []
    cities = []

    for city in target:
        train = spacenet.Spacenet(city=city,
                                  split='train',
                                  img_root=img_root)
        loader = DataLoader(train, batch_size=16, shuffle=False, num_workers=2)
        n_before = len(images)
        for i, sample in enumerate(loader):
            image, gt = sample['image'], sample['label']
            activation = torch.load(
                os.path.join(save_path, "{}Feats_{}.pt".format(city[0], i)))
            if image.shape[0] != activation.shape[0]:
                raise ValueError(
                    'batch {} of {}: {} images but {} activations'.format(
                        i, city, image.shape[0], activation.shape[0]))
            # Iterating a numpy array yields one per-sample row.
            images.extend(image.cpu().numpy())
            activations.extend(activation.cpu().numpy())
            gts.extend(gt.cpu().numpy())
        # Tag exactly the samples added for this city. (Previously the counter
        # was reset per city and cumulative counts were used, so later cities
        # were tagged with the first city's sample count.)
        cities += [city] * (len(images) - n_before)

    return images, activations, gts, cities
Ejemplo n.º 8
0
def compute_mean_variance():
    """Estimate per-channel image statistics for ``config.dataset``'s train split.

    Reads ``dataset``, ``img_root``, ``batch_size`` and ``train_num_workers``
    from the module-level ``config``. The last incomplete batch is dropped.

    NOTE: the reported "std" is the average of per-image channel standard
    deviations, not the pooled standard deviation over all pixels.

    Returns:
        tuple: ``(mean, std)`` per-channel tensors (also printed).

    Raises:
        ValueError: if the loader yields no complete batch.
    """
    train_set = spacenet.Spacenet(city=config.dataset,
                                  split='train',
                                  img_root=config.img_root)
    # shuffle=False: together with drop_last=True, shuffling made the dropped
    # samples — and therefore the statistics — change from run to run.
    loader = DataLoader(train_set,
                        batch_size=config.batch_size,
                        shuffle=False,
                        num_workers=config.train_num_workers,
                        drop_last=True)
    mean = 0.
    std = 0.
    nb_samples = 0.
    for sample in tqdm(loader):
        data = sample['image']
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        nb_samples += batch_samples

    if not nb_samples:
        raise ValueError('no complete batches in the train split of {!r}'.format(config.dataset))

    mean /= nb_samples
    std /= nb_samples

    print('mean: ', mean)
    print('std:', std)
    return mean, std
Ejemplo n.º 9
0
def getImagesActivations(target=('Vegas', 'Shanghai'),
                         img_root='/usr/xtmp/satellite/spacenet/'):
    """Collect validation images and labels for each city, tagged by city.

    Despite its name this variant loads no activations. Each city's 'val'
    split is iterated in order (batch size 16, no shuffling).

    Args:
        target: city names to load, in order.
        img_root: dataset image root passed to ``spacenet.Spacenet``.

    Returns:
        tuple: ``(images, gts, cities)`` — parallel lists with one numpy
        array (or city name) per sample.
    """
    images = []
    gts = []
    cities = []

    for city in target:
        val = spacenet.Spacenet(city=city,
                                split='val',
                                img_root=img_root)
        loader = DataLoader(val, batch_size=16, shuffle=False, num_workers=2)
        n_before = len(images)
        for sample in loader:
            # Iterating a numpy array yields one per-sample row.
            images.extend(sample['image'].cpu().numpy())
            gts.extend(sample['label'].cpu().numpy())
        # Tag exactly the samples added for this city. (Previously the counter
        # was reset per city and cumulative counts were used, so later cities
        # were tagged with the first city's sample count.)
        cities += [city] * (len(images) - n_before)

    return images, gts, cities
Ejemplo n.º 10
0
    def __init__(self, model_path, config, bn, save_path, save_batch, cuda=False):
        """Load a DeepLab checkpoint and loaders for the source and target cities.

        ``config.dataset`` is the source city (its 'test' split is loaded).
        Every other city in ``config.all_dataset`` is a target, for which both
        the 'test' split (``target_set``/``target_loader``) and the 'train'
        split (``target_trainset``/``target_trainloader``) are loaded.

        Args:
            model_path: path passed to ``torch.load``; the loaded object is fed
                directly to ``load_state_dict``.
            config: experiment config (dataset and model options).
            bn, save_path, save_batch: stored as attributes.
            cuda: when true, load the checkpoint on its saved device and move
                the model to the GPU; otherwise map it to the CPU.

        Raises:
            ValueError: if ``config.dataset`` is not in ``config.all_dataset``.
        """
        self.bn = bn
        # Work on a copy: an in-place remove() mutated the caller's config and
        # made a second construction with the same config raise ValueError.
        self.target = list(config.all_dataset)
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset, split='test', img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set, batch_size=16, shuffle=False, num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # load other domains
        for city in self.target:
            test = spacenet.Spacenet(city=city, split='test', img_root=config.img_root)
            self.target_set.append(test)
            self.target_loader.append(DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train = spacenet.Spacenet(city=city, split='train', img_root=config.img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeepLab(num_classes=2,
                backbone=config.backbone,
                output_stride=config.out_stride,
                sync_bn=config.sync_bn,
                freeze_bn=config.freeze_bn)
        # map_location=None keeps tensors on the device they were saved from.
        map_location = None if cuda else torch.device('cpu')
        self.checkpoint = torch.load(model_path, map_location=map_location)
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()
Ejemplo n.º 11
0
def make_data_loader(config):
    """Create train and val DataLoaders for the city named by ``config.dataset``.

    Both loaders use ``config.batch_size`` and drop the last incomplete
    batch; only the training loader shuffles.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)`` where
        ``test_loader`` is always ``None``.
    """
    def dataset_for(split_name):
        return spacenet.Spacenet(city=config.dataset,
                                 split=split_name,
                                 img_root=config.img_root)

    train_data = dataset_for('train')
    val_data = dataset_for('val')

    common = {'batch_size': config.batch_size, 'drop_last': True}
    train_dl = DataLoader(train_data,
                          shuffle=True,
                          num_workers=config.train_num_workers,
                          **common)
    val_dl = DataLoader(val_data,
                        shuffle=False,
                        num_workers=config.val_num_workers,
                        **common)
    return train_dl, val_dl, None
Ejemplo n.º 12
0
    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        """Load a DeepLab checkpoint and a 'train' loader for ``config.dataset``.

        The single loader is kept in one-element lists
        (``target_trainset``/``target_trainloader``) for interface parity
        with the multi-city variants.

        Args:
            model_path: path passed to ``torch.load``; the loaded object is fed
                directly to ``load_state_dict``.
            config: experiment config (dataset and model options).
            bn, save_path, save_batch: stored as attributes.
            cuda: when true, load the checkpoint on its saved device and move
                the model to the GPU; otherwise map the checkpoint to the CPU.
        """
        self.bn = bn
        self.city = config.dataset
        self.save_path = save_path
        self.save_batch = save_batch
        self.config = config

        train = spacenet.Spacenet(city=self.city,
                                  split='train',
                                  img_root=config.img_root)
        self.target_trainset = [train]
        self.target_trainloader = [
            DataLoader(train, batch_size=16, shuffle=False, num_workers=2)]

        self.model = DeepLab(num_classes=2,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)
        self.evaluator = Evaluator(2)
        self.cuda = cuda

        # Honour cuda=False: without map_location a GPU-saved checkpoint cannot
        # be loaded on a CPU-only host.
        map_location = None if cuda else torch.device('cpu')
        self.checkpoint = torch.load(model_path, map_location=map_location)
        self.model.load_state_dict(self.checkpoint)
        # Move to the GPU once, after the weights are in place.
        if cuda:
            self.model = self.model.cuda()
Ejemplo n.º 13
0
        #grid_image = make_grid(torch.from_numpy(pred[:,None,:,:]), 5, normalize=False)
        #self.writer.add_image(prefix+'/Prediction', grid_image, global_step)
        #grid_image = make_grid(torch.from_numpy(target[:,None,:,:]), 5, normalize=False)
        #self.writer.add_image(prefix+'/GT', grid_image, global_step)
        grid_image = make_grid(torch.from_numpy(color_images(pred, target)),
                               5,
                               normalize=True,
                               range=(0, 255))
        self.writer.add_image(prefix + '/Color', grid_image, global_step)

        #grid_image = make_grid(images, 5, normalize=False)
        #self.writer.add_image(prefix+'/Image', grid_image, global_step)


if __name__ == "__main__":
    # Smoke test: log colour visualisations for the first four Vegas batches.
    # NOTE(review): `dist` and `TensorboardSummary` must be defined at module
    # level (not visible in this snippet).
    import sys
    sys.path.append(os.getcwd())  # make the local `data` package importable
    from data import spacenet
    from torch.utils.data import DataLoader

    summary = TensorboardSummary('log')

    dataset = spacenet.Spacenet('Vegas', source_dist=dist['Vegas'])
    loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=2)
    # zip with range(4) stops after four batches without pulling a fifth.
    for i, sample in zip(range(4), loader):
        image = sample['image']
        target = sample['label']
        # The image's first two channels are passed as the fifth argument,
        # presumably standing in for a model prediction.
        summary.visualize_image('train', 'Vegas', image, target,
                                image[:, :2, :, :], i)
Ejemplo n.º 14
0
def _next_batch(loader_iter, loader):
    """Return ``(loader_iter, batch)``, restarting from ``loader`` when exhausted.

    ``loader_iter`` is an ``enumerate`` over ``loader``; the (possibly new)
    iterator is returned so the caller can keep cycling through epochs.
    """
    try:
        _, batch = next(loader_iter)
    except StopIteration:
        loader_iter = enumerate(loader)
        _, batch = next(loader_iter)
    return loader_iter, batch


def _domain_labels(d_out, label):
    """Return a tensor with ``d_out``'s shape/dtype/device filled with ``label``."""
    return torch.full_like(d_out, label)


def main():
    """Create the model and start the training.

    Adversarial domain-adaptation loop: a two-output segmentation network
    (``DeeplabMulti``) is trained with ``loss_calc`` on labelled source-city
    batches. Two discriminators (``FCDiscriminator``), one per output, learn to
    tell softmaxed source predictions (label 0) from target predictions
    (label 1). The segmentation network is also pushed to make its target
    predictions be classified as source.

    Reads hyper-parameters from the module-level ``args`` and dataset settings
    (``dataset``, ``target``, ``img_root``) from the module-level ``config``.
    Snapshots are written to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and when ``args.num_steps_stop`` is
    reached.

    Raises:
        ValueError: for an unsupported ``args.model`` or ``args.gan``
            (previously these surfaced later as a NameError).
    """
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    # Fail fast on unsupported options instead of a NameError mid-run.
    if args.model != 'DeepLab':
        raise ValueError('unsupported model: {}'.format(args.model))
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError('unsupported gan type: {}'.format(args.gan))

    cudnn.enabled = True

    # Create network. NOTE: no pretrained weights are loaded here
    # (args.restore_from is not used); the network keeps its own initialisation.
    model = DeeplabMulti(num_classes=args.num_classes)
    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    for model_D in (model_D1, model_D2):
        model_D.train()
        model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # NOTE(review): dataset settings come from the module-level `config`,
    # batch/worker settings from `args` — confirm both are populated.
    train_set = spacenet.Spacenet(city=config.dataset, split='train', img_root=config.img_root)
    trainloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.num_workers, drop_last=True)
    trainloader_iter = enumerate(trainloader)

    target_set = spacenet.Spacenet(city=config.target, split='train', img_root=config.img_root)
    targetloader = DataLoader(target_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def save_snapshot(tag):
        # Writes <snapshot_dir>/paris_<tag>{,_D1,_D2}.pth
        prefix = osp.join(args.snapshot_dir, 'paris_' + str(tag))
        torch.save(model.state_dict(), prefix + '.pth')
        torch.save(model_D1.state_dict(), prefix + '_D1.pth')
        torch.save(model_D2.state_dict(), prefix + '_D2.pth')

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            # ---- train G ----
            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False
            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            trainloader_iter, batch = _next_batch(trainloader_iter, trainloader)
            # Spacenet samples are dicts: unpacking the dict itself would
            # yield its keys, not the tensors.
            images, labels = batch['image'], batch['label']
            images = images.cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target: make target predictions look like source
            targetloader_iter, batch = _next_batch(targetloader_iter, targetloader)
            images = batch['image'].cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # dim=1 is the channel axis; it is also what the implicit
            # (deprecated) dim chose for 4-D inputs.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(D_out1, _domain_labels(D_out1, source_label))
            loss_adv_target2 = bce_loss(D_out2, _domain_labels(D_out2, source_label))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D ----
            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True
            for param in model_D2.parameters():
                param.requires_grad = True

            # Source predictions, then target predictions. Detached so no
            # gradient flows back into the segmentation network.
            for seg1, seg2, label in ((pred1.detach(), pred2.detach(), source_label),
                                      (pred_target1.detach(), pred_target2.detach(), target_label)):
                D_out1 = model_D1(F.softmax(seg1, dim=1))
                D_out2 = model_D2(F.softmax(seg2, dim=1))

                loss_D1 = bce_loss(D_out1, _domain_labels(D_out1, label)) / args.iter_size / 2
                loss_D2 = bce_loss(D_out2, _domain_labels(D_out2, label)) / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, '
              'loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
                  i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                  loss_adv_target_value1, loss_adv_target_value2,
                  loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_snapshot(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_snapshot(i_iter)