Example #1
0
def main(args):
    """Create the model and start the evaluation process.

    Every image listed in ``args.img_list`` is run through the model. The
    logits are upsampled to 1024x2048 (H x W) and reduced to per-pixel class
    ids with ``argmax``. Each result is written to ``args.save`` twice:
    - as a raw ``uint8`` label image under the input's base file name;
    - as a colourised ``<stem>_color.png`` produced by ``colorize_mask``.

    Args:
        args: parsed command-line namespace. Reads ``gpu``, ``save``,
            ``model`` ('DeeplabMulti' or 'DeeplabVGG'), ``num_classes``,
            ``restore_from`` (an http(s) URL or a local checkpoint path),
            ``data_dir``, ``img_list``, ``lbl_list`` and ``set``.

    Raises:
        ValueError: if ``args.model`` names an unsupported architecture.
            Previously this surfaced later as an unbound-name error.
    """
    gpu0 = args.gpu

    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        # If the checkpoint is still the generic default, use the
        # VGG-specific default instead.
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))

    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        # Load onto the CPU so a checkpoint saved from another GPU still
        # opens; the model (and thus its weights) is moved to gpu0 below.
        saved_state_dict = torch.load(args.restore_from, map_location='cpu')
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(ListDataSet(args.data_dir,
                                             args.img_list,
                                             args.lbl_list,
                                             crop_size=(1024, 512),
                                             mean=IMG_MEAN,
                                             split=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, _, name = batch
            image = image.cuda(gpu0)
            if args.model == 'DeeplabMulti':
                # DeeplabMulti returns two outputs; evaluate the second one.
                _, logits = model(image)
            else:
                logits = model(image)
            # First (and only, batch_size=1) sample: (C, H, W) scores.
            scores = interp(logits)[0].cpu().numpy()
            # Class id per pixel -> (H, W) uint8 label map.
            label = np.argmax(scores, axis=0).astype(np.uint8)

            output_col = colorize_mask(label)
            output = Image.fromarray(label)

            base_name = os.path.basename(name[0])
            # splitext keeps dots inside the stem; split('.')[0] truncated
            # e.g. "a.b.png" to "a".
            stem = os.path.splitext(base_name)[0]
            output.save(os.path.join(args.save, base_name))
            output_col.save(os.path.join(args.save, stem + '_color.png'))
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    criterion = DiceBCELoss()
    # criterion = nn.CrossEntropyLoss(ignore_index=253)
    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from is None:
            pass
        elif args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from is not None:
            saved_state_dict = torch.load(args.restore_from)

        if args.restore_from is not None:
            new_params = model.state_dict().copy()
            for i in saved_state_dict:
                # Scale.layer5.conv2d_list.3.weight
                i_parts = i.split('.')
                # print i_parts
                if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                    # print i_parts
            model.load_state_dict(new_params)

    if not args.no_logging:
        if not os.path.isdir(args.log_dir):
            os.mkdir(args.log_dir)
        log_dir = os.path.join(args.log_dir, args.exp_dir)
        if not os.path.isdir(log_dir):
            os.mkdir(log_dir)
        if args.exp_name == "":
            exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
        else:
            exp_name = args.exp_name
        log_dir = os.path.join(log_dir, exp_name)
        writer = SummaryWriter(log_dir)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(SyntheticSmokeTrain(
        args={},
        dataset_limit=args.num_steps * args.iter_size * args.batch_size,
        image_shape=input_size,
        dataset_mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    print("Length of train dataloader: ", len(trainloader))
    targetloader = data.DataLoader(SimpleSmokeVal(args={},
                                                  image_size=input_size_target,
                                                  dataset_mean=IMG_MEAN),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    print("Length of train dataloader: ", len(targetloader))
    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
        # bce_loss_all = torch.nn.BCEWithLogitsLoss(reduction='none')
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
        # bce_loss_all = torch.nn.MSELoss(reduction='none')

    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    # interp_domain = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            for param in interp_domain.parameters():
                param.requires_grad = False

            # train with source
            # try:
            _, batch = next(trainloader_iter)  #.next()
            # except StopIteration:
            # trainloader = data.DataLoader(
            #     SyntheticSmokeTrain(args={}, dataset_limit=args.num_steps * args.iter_size * args.batch_size,
            #                 image_shape=input_size, dataset_mean=IMG_MEAN),
            #     batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
            # trainloader_iter = iter(trainloader)
            # _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            # print("Shape of labels", labels.shape)
            # print("Are labels all zero? ")
            # for i in range(labels.shape[0]):
            #     print("{}: All zero? {}".format(i, torch.all(labels[i]==0)))
            #     print("{}: All 255? {}".format(i, torch.all(labels[i]==255)))
            #     print("{}: Mean = {}".format(i, torch.mean(labels[i])))

            pred1, pred2 = model(images)
            # print("Pred1 and Pred2 original size: {}, {}".format(pred1.shape, pred2.shape))
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            # print("Pred1 and Pred2 upsampled size: {}, {}".format(pred1.shape, pred2.shape))
            # for pred, name in zip([pred1, pred2], ['pred1', 'pred2']):
            #     print(name)
            #     for i in range(pred.shape[0]):
            #         print("{}: All zero? {}".format(i, torch.all(pred[i]==0)))
            #         print("{}: All 255? {}".format(i, torch.all(pred[i]==255)))
            #         print("{}: Mean = {}".format(i, torch.mean(pred[i])))

            loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            # print("Seg1 loss: ",loss_seg1, args.iter_size)
            # print("Seg2 loss: ",loss_seg2, args.iter_size)

            loss_seg_value1 += loss_seg1.detach().data.cpu().item(
            ) / args.iter_size
            loss_seg_value2 += loss_seg2.detach().data.cpu().item(
            ) / args.iter_size

            # train with target
            # try:
            _, batch = next(targetloader_iter)  #.next()
            # except StopIteration:
            #     targetloader = data.DataLoader(
            #         SimpleSmokeVal(args = {}, image_size=input_size_target, dataset_mean=IMG_MEAN),
            #                         batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
            #                         pin_memory=True)
            #     targetloader_iter = iter(targetloader)
            #     _, batch = next(targetloader_iter)

            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            # w1 = torch.argmax(pred_target1.detach(), dim=1)
            # w2 = torch.argmax(pred_target2.detach(), dim=1)

            min_class1 = sorted([(k, v)
                                 for k, v in Counter(w1.ravel()).items()],
                                key=lambda x: x[1])[0][0]
            min_class2 = sorted([(k, v)
                                 for k, v in Counter(w2.ravel()).items()],
                                key=lambda x: x[1])[0][0]

            # m1 = torch.where(w1==min_class1)
            # m1c = torch.where(w1!=min_class1)
            # w1[m1] = 11
            # w1[m1c] = 1

            # m2 = torch.where(w2==min_class2)
            # m2c = torch.where(w2!=min_class2)
            # w2[m2] = 11
            # w2[m2c] = 1

            # D_out1 = interp_domain(D_out1)
            # D_out2 = interp_domain(D_out2)

            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.detach().data.cpu(
            ).item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.detach().data.cpu(
            ).item() / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.detach().data.cpu().item()
            loss_D_value2 += loss_D2.detach().data.cpu().item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(
                            args.gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.detach().data.cpu().item()
            loss_D_value2 += loss_D2.detach().data.cpu().item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))
        writer.add_scalar(f'loss/train/segmentation/1', loss_seg_value1,
                          i_iter)
        writer.add_scalar(f'loss/train/segmentation/2', loss_seg_value2,
                          i_iter)
        writer.add_scalar(f'loss/train/adversarial/1', loss_adv_target_value1,
                          i_iter)
        writer.add_scalar(f'loss/train/adversarial/2', loss_adv_target_value2,
                          i_iter)
        writer.add_scalar(f'loss/train/domain/1', loss_D_value1, i_iter)
        writer.add_scalar(f'loss/train/domain/2', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'lmda_adv_0.1_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'lmda_adv_0.1_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'lmda_adv_0.1_' + str(i_iter) + '_D2.pth'))
        writer.flush()
Example #3
0
class Test:
    """Evaluate a two-class ``DeeplabMulti`` checkpoint on SpaceNet cities.

    ``config.dataset`` is the source city. Every other city in
    ``config.all_dataset`` is evaluated as a target domain. When ``bn`` is
    true, BatchNorm statistics are re-estimated on each target city's train
    split before that city is tested.
    """

    def __init__(self,
                 model_path,
                 config,
                 bn,
                 save_path,
                 save_batch,
                 cuda=False):
        """Build the data loaders and load the model checkpoint.

        Args:
            model_path: path to a ``state_dict`` for
                ``DeeplabMulti(num_classes=2)``.
            config: needs ``all_dataset`` (city names), ``dataset`` (the
                source city; must appear in ``all_dataset``) and ``img_root``.
            bn: re-estimate BatchNorm statistics per target city when true.
            save_path: root directory for saved visualisations.
            save_batch: number of test batches per city to save as images.
            cuda: run on the GPU when true; otherwise load onto the CPU.

        Raises:
            ValueError: if ``config.dataset`` is not in ``config.all_dataset``.
        """
        self.bn = bn
        # Copy first: removing from config.all_dataset directly would mutate
        # the caller's config (and break a second Test built from it).
        self.target = list(config.all_dataset)
        self.target.remove(config.dataset)
        # load source domain
        self.source_set = spacenet.Spacenet(city=config.dataset,
                                            split='test',
                                            img_root=config.img_root)
        self.source_loader = DataLoader(self.source_set,
                                        batch_size=16,
                                        shuffle=False,
                                        num_workers=2)

        self.save_path = save_path
        self.save_batch = save_batch

        self.target_set = []
        self.target_loader = []

        self.target_trainset = []
        self.target_trainloader = []

        self.config = config

        # Load the other domains: a test split for evaluation and a train
        # split that is only used for BN adaptation.
        for city in self.target:
            test = spacenet.Spacenet(city=city,
                                     split='test',
                                     img_root=config.img_root)
            self.target_set.append(test)
            self.target_loader.append(
                DataLoader(test, batch_size=16, shuffle=False, num_workers=2))
            train = spacenet.Spacenet(city=city,
                                      split='train',
                                      img_root=config.img_root)
            self.target_trainset.append(train)
            self.target_trainloader.append(
                DataLoader(train, batch_size=16, shuffle=False, num_workers=2))

        self.model = DeeplabMulti(num_classes=2)
        if cuda:
            self.checkpoint = torch.load(model_path)
        else:
            self.checkpoint = torch.load(model_path,
                                         map_location=torch.device('cpu'))
        self.model.load_state_dict(self.checkpoint)
        self.evaluator = Evaluator(2)
        self.cuda = cuda
        if cuda:
            self.model = self.model.cuda()

    def get_performance(self, dataloader, trainloader, city):
        """Evaluate the model on ``dataloader`` for ``city``.

        BN adaptation runs only when ``self.bn`` is true and ``city`` is not
        the source city. In that case the model first runs in train mode,
        without gradients, over ``trainloader``. This updates the BatchNorm
        running statistics before evaluation.

        Images from up to ``self.save_batch`` test batches are written under
        ``<save_path>/<city>`` (or ``<city>_bn`` when ``self.bn``). For each
        batch, both a colour-coded prediction and the input image are saved.

        Returns:
            ``(Acc, IoU, mIoU)`` as reported by the evaluator.
        """
        # change mean and var of bn to adapt to the target domain
        if self.bn and city != self.config.dataset:
            print('BN Adaptation on ' + city)
            self.model.train()
            with torch.no_grad():
                for sample in trainloader:
                    # NOTE(review): unpacked as (image, label), matching the
                    # evaluation loop below over the same dataset class.
                    # Previously this indexed sample['image'] — confirm
                    # against spacenet.Spacenet.
                    image, _ = sample
                    if self.cuda:
                        image = image.cuda()
                    self.model(image)

        batches_to_save = self.save_batch
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(dataloader, desc='\r')

        # save in different directories
        save_path = os.path.join(self.save_path,
                                 city + '_bn' if self.bn else city)

        # align_corners=False is the default behaviour; stated explicitly.
        interp = nn.Upsample(size=(400, 400),
                             mode='bilinear',
                             align_corners=False)

        # evaluate on the test dataset
        for sample in tbar:
            image, target = sample
            if self.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                _, output = self.model(image)
                output = interp(output)
            # (N, C, H, W) scores -> (N, H, W) predicted class ids
            pred = np.argmax(output.data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

            # save pictures
            if batches_to_save > 0:
                os.makedirs(save_path, exist_ok=True)
                image = image.cpu().numpy() * 255
                # (N, C, H, W) -> (N, H, W, C) for image writing
                image = image.transpose(0, 2, 3, 1).astype(int)

                imgs = self.color_images(pred, target)
                self.save_images(imgs, batches_to_save, save_path, False)
                self.save_images(image, batches_to_save, save_path, True)
                batches_to_save -= 1

        Acc = self.evaluator.Building_Acc()
        IoU = self.evaluator.Building_IoU()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        return Acc, IoU, mIoU

    def test(self):
        """Evaluate the source and every target city, print and dump results.

        Results go to ``train_log/test_bn.json`` when ``self.bn`` is true,
        otherwise to ``train_log/test.json``. The directory is created if
        missing.
        """
        A, I, Im = self.get_performance(self.source_loader, None,
                                        self.config.dataset)
        tA, tI, tIm = [], [], []
        for dl, tl, city in zip(self.target_loader, self.target_trainloader,
                                self.target):
            tA_, tI_, tIm_ = self.get_performance(dl, tl, city)
            tA.append(tA_)
            tI.append(tI_)
            tIm.append(tIm_)

        res = {}
        print("Test for source domain:")
        print("{}: Acc:{}, IoU:{}, mIoU:{}".format(self.config.dataset, A, I,
                                                   Im))
        res[self.config.dataset] = {'Acc': A, 'IoU': I, 'mIoU': Im}

        print('Test for target domain:')
        for i, city in enumerate(self.target):
            print("{}: Acc:{}, IoU:{}, mIoU:{}".format(city, tA[i], tI[i],
                                                       tIm[i]))
            res[city] = {'Acc': tA[i], 'IoU': tI[i], 'mIoU': tIm[i]}

        name = 'train_log/test_bn.json' if self.bn else 'train_log/test.json'
        os.makedirs(os.path.dirname(name), exist_ok=True)
        with open(name, 'w') as f:
            json.dump(res, f)

    def save_images(self, imgs, batch_index, save_path, if_original=False):
        """Write each image in ``imgs`` to ``save_path`` as a JPEG.

        Files are named ``<batch_index>_<i>_Original.jpg`` when
        ``if_original`` is true, otherwise ``<batch_index>_<i>_Pred.jpg``.
        The separator prevents collisions such as batch 1/image 12 vs
        batch 11/image 2.
        """
        suffix = '_Original.jpg' if if_original else '_Pred.jpg'
        for i, img in enumerate(imgs):
            filename = '{}_{}{}'.format(batch_index, i, suffix)
            # Reverse the channel axis (RGB -> BGR) because cv2.imwrite
            # expects BGR. Assumes inputs are RGB.
            cv2.imwrite(os.path.join(save_path, filename), img[:, :, ::-1])

    def color_images(self, pred, target):
        """Colour-code per-pixel agreement between predictions and labels.

        For each pair of ``(H, W)`` maps ``p`` (prediction) and ``t`` (label),
        the code ``2 * p + t`` selects an RGB colour:
        - 0: black, true negative
        - 1: red, false negative
        - 2: green, false positive
        - 3: yellow, true positive

        Any other code stays black.

        Returns:
            A list of float ``(H, W, 3)`` arrays, one per input pair.
        """
        # bkg: negative, building: positive
        palette = {
            1: (255, 0, 0),  # Red RGB, for false negative
            2: (0, 255, 0),  # Green RGB, for false positive
            3: (255, 255, 0),  # Yellow RGB, for true positive
        }
        imgs = []
        for p, t in zip(pred, target):
            code = p * 2 + t
            # zeros == black, i.e. true negatives need no explicit assignment
            img = np.zeros((p.shape[0], p.shape[1], 3))
            for value, rgb in palette.items():
                img[code == value] = rgb
            imgs.append(img)
        return imgs
Example #4
0
def main():

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network
    model = DeeplabMulti(num_classes=args.num_classes)
    #model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from, map_location='cuda:0')

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size(
        ) == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)
    #summary(model,(3,7,7))

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)
    #summary(model_D, (21,321,321))
    #quit()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                      args.iter_size * args.batch_size,
                                      scale=args.random_scale)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = cityscapesDataSet(max_iters=args.num_steps *
                                         args.iter_size * args.batch_size,
                                         scale=args.random_scale)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            train_ids = pickle.load(open(args.partial_id))
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        pickle.dump(train_ids,
                    open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb'))

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             sampler=train_remain_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         sampler=train_gt_sampler,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers,
                                         pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0
                ) and i_iter >= args.semi_start_adv:
                try:
                    _, batch = trainloader_remain_iter.__next__()
                except:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = trainloader_remain_iter.__next__()

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                try:
                    pred = interp(model(images))
                except RuntimeError as exception:
                    if "out of memory" in str(exception):
                        print("WARNING: out of memory")
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise exception

                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred)))
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    np.bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size

                #loss_semi_adv.backward()
                loss_semi_adv_value += loss_semi_adv.data.cpu().numpy(
                ) / args.lambda_semi_adv

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(
                        semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(
                            pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += loss_semi.data.cpu().numpy(
                        ) / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()

            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source

            try:
                _, batch = trainloader_iter.__next__()
            except:
                trainloader_iter = enumerate(trainloader)
                _, batch = trainloader_iter.__next__()

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            try:
                pred = interp(model(images))
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred)))

            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            if args.D_remain:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain),
                                             axis=0)

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt
            # get gt labels
            try:
                _, batch = trainloader_gt_iter.__next__()
            except:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = trainloader_gt_iter.__next__()

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '_D.pth'))
            #torch.cuda.empty_cache()

    end = timeit.default_timer()
    print(end - start, 'seconds')
def train(gpu, args):
    """Create the model and discriminators and run distributed adversarial training.

    One call drives one process of a DistributedDataParallel job (presumably
    launched once per local GPU, e.g. via ``torch.multiprocessing.spawn`` --
    the launcher is not visible here). A two-output DeepLab network is trained
    with a segmentation loss on the ``SyntheticSmokeTrain`` source set, plus an
    adversarial loss that pushes two ``FCDiscriminator`` models (one per
    output) to label ``SimpleSmokeVal`` target predictions as source.

    Args:
        gpu: Local GPU index of this process; combined with ``args.nr`` and
            ``args.num_gpus`` to form the global rank.
        args: Parsed options namespace (model, data, optimiser, logging and
            checkpoint settings are read from it).

    Raises:
        ValueError: If ``args.model`` or ``args.gan`` is unsupported, if
            ``batch_size == 1`` is combined with ``use_bn``, or if a data
            loader yields no batches.
    """

    def _strip_module_prefix(state_dict):
        # Checkpoints saved below come from DistributedDataParallel wrappers, so
        # their keys start with 'module.'; the bare modules loaded here do not.
        prefix = 'module.'
        return {(key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()}

    def _load_matching(module, state_dict):
        # Partial load for remote pretrained weights: copy each tensor whose key,
        # as-is or with its first dotted component dropped, exists in ``module``
        # with an identical shape; other parameters keep their initial values.
        # Returns the number of tensors copied.
        own = module.state_dict()
        loaded = 0
        for key, value in state_dict.items():
            shape = getattr(value, 'shape', None)
            for candidate in (key, key.split('.', 1)[-1]):
                if candidate in own and own[candidate].shape == shape:
                    own[candidate] = value
                    loaded += 1
                    break
        module.load_state_dict(own)
        return loaded

    def _endless_batches(loader, sampler):
        # Yield batches forever, starting a new pass (reshuffled via set_epoch)
        # whenever ``loader`` is exhausted. Each rank's DistributedSampler only
        # covers 1/world_size of the dataset, so one pass can be shorter than
        # the num_steps * iter_size batches the loop below consumes.
        if len(loader) == 0:
            raise ValueError('data loader yields no batches')
        epoch = 0
        while True:
            sampler.set_epoch(epoch)
            for batch in loader:
                yield batch
            epoch += 1

    def _set_requires_grad(modules, flag):
        # Toggle gradient tracking for every parameter of the given modules.
        for module in modules:
            for param in module.parameters():
                param.requires_grad = flag

    def _domain_label(d_out, value):
        # Constant target map with the discriminator output's shape/dtype/device.
        return torch.full_like(d_out, value)

    # Fail fast on unsupported configurations, before any process-group setup.
    if args.model != 'DeepLab':
        raise ValueError('Unsupported model: {}'.format(args.model))
    if args.gan not in ('Vanilla', 'LS'):
        raise ValueError('Unsupported GAN loss: {}'.format(args.gan))
    if args.batch_size == 1 and args.use_bn is True:
        raise ValueError('batch_size must be greater than 1 when use_bn is set')

    rank = args.nr * args.num_gpus + gpu
    is_main = rank == 0  # only rank 0 writes logs and checkpoints
    # NOTE(review): machine-specific remap kept from the original code: local
    # process 1 is pinned to physical GPU 3. Confirm before reusing this script.
    if gpu == 1:
        gpu = 3

    dist.init_process_group(backend="nccl",
                            world_size=args.world_size,
                            rank=rank)

    torch.autograd.set_detect_anomaly(True)
    torch.manual_seed(args.torch_seed)
    torch.cuda.manual_seed(args.cuda_seed)

    torch.cuda.set_device(gpu)

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    criterion = DiceBCELoss()
    # criterion = nn.CrossEntropyLoss(ignore_index=253)

    # Create network. Checkpoints are loaded to CPU first so that every rank
    # does not materialise them on the default GPU.
    model = DeeplabMulti(num_classes=args.num_classes)
    restore_from = args.restore_from
    is_remote = restore_from is not None and restore_from[:4] == 'http'
    if is_remote:
        n_loaded = _load_matching(
            model, model_zoo.load_url(restore_from, map_location='cpu'))
        print("Loaded {} pretrained tensors for model".format(n_loaded))
    elif restore_from is not None:
        saved_state_dict = torch.load(restore_from, map_location='cpu')
        model.load_state_dict(_strip_module_prefix(saved_state_dict))
        print("Loaded state dicts for model")

    writer = None
    if not args.no_logging and is_main:
        exp_name = (args.exp_name or
                    datetime.datetime.now().strftime("%H%M%S-%Y%m%d"))
        log_dir = os.path.join(args.log_dir, args.exp_dir, exp_name)
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)

    use_ddp = args.num_gpus > 0 or torch.cuda.device_count() > 0

    model.train()
    model = model.cuda(device=gpu)
    if use_ddp:
        model = DistributedDataParallel(model,
                                        device_ids=[gpu],
                                        find_unused_parameters=True)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # init D: when resuming from a local checkpoint "<root><ext>", the
    # discriminators live next to it as "<root>_D1<ext>" / "<root>_D2<ext>" and
    # the last number in the file name is the iteration to resume from.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    start_iter = 0
    if restore_from is not None and not is_remote:
        root, extension = os.path.splitext(restore_from.strip())
        for model_D, suffix in ((model_D1, "_D1"), (model_D2, "_D2")):
            saved_state_dict = torch.load(root + suffix + extension,
                                          map_location='cpu')
            model_D.load_state_dict(_strip_module_prefix(saved_state_dict))
        digits = re.findall(r'\d+', os.path.basename(root))
        start_iter = int(digits[-1]) if digits else 0
        print("Loaded state dict for models D1 and D2")

    model_D1.train()
    model_D2.train()
    model_D1 = model_D1.cuda(device=gpu)
    model_D2 = model_D2.cuda(device=gpu)
    if use_ddp:
        model_D1 = DistributedDataParallel(model_D1,
                                           device_ids=[gpu],
                                           find_unused_parameters=True)
        model_D2 = DistributedDataParallel(model_D2,
                                           device_ids=[gpu],
                                           find_unused_parameters=True)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    train_dataset = SyntheticSmokeTrain(args={},
                                        dataset_limit=args.num_steps *
                                        args.iter_size * args.batch_size,
                                        image_shape=input_size,
                                        dataset_mean=IMG_MEAN)
    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=args.world_size,
                                       rank=rank,
                                       shuffle=True)
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  sampler=train_sampler)
    train_batches = _endless_batches(trainloader, train_sampler)
    print("Length of train dataloader: ", len(trainloader))

    target_dataset = SimpleSmokeVal(args={},
                                    image_size=input_size_target,
                                    dataset_mean=IMG_MEAN)
    target_sampler = DistributedSampler(target_dataset,
                                        num_replicas=args.world_size,
                                        rank=rank,
                                        shuffle=True)
    targetloader = data.DataLoader(target_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   pin_memory=True,
                                   sampler=target_sampler)
    target_batches = _endless_batches(targetloader, target_sampler)
    print("Length of target dataloader: ", len(targetloader))

    # optim_parameters(args) lets the model define per-layer learning rates.
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    else:  # 'LS' (validated above): least-squares GAN objective
        bce_loss = torch.nn.MSELoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def _train_discriminators(p1, p2, label):
        # Half of one D step: classify detached predictions as ``label`` and
        # accumulate gradients; returns the two (scaled) loss values.
        d_out1 = model_D1(F.softmax(p1.detach(), dim=1))
        d_out2 = model_D2(F.softmax(p2.detach(), dim=1))
        loss_d1 = bce_loss(d_out1, _domain_label(d_out1, label))
        loss_d2 = bce_loss(d_out2, _domain_label(d_out2, label))
        loss_d1 = loss_d1 / args.iter_size / 2
        loss_d2 = loss_d2 / args.iter_size / 2
        loss_d1.backward()
        loss_d2.backward()
        return loss_d1.item(), loss_d2.item()

    def _save_checkpoint(tag):
        # Writes <snapshot_dir>/smoke_cross_entropy_multigpu_<tag>{,_D1,_D2}.pth
        # (DDP state dicts, i.e. keys prefixed with 'module.').
        prefix = osp.join(args.snapshot_dir,
                          'smoke_cross_entropy_multigpu_' + str(tag))
        torch.save(model.state_dict(), prefix + '.pth')
        torch.save(model_D1.state_dict(), prefix + '_D1.pth')
        torch.save(model_D2.state_dict(), prefix + '_D2.pth')

    for i_iter in range(start_iter, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            # ---- train G (don't accumulate grads in D) ----
            _set_requires_grad((model_D1, model_D2), False)

            # train with source: supervised segmentation loss on both outputs
            images, labels, _, _ = next(train_batches)
            images = images.cuda(gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization over gradient-accumulation sub-steps
            loss = loss / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target: fool D by asking it to call target "source"
            images, _, _ = next(target_batches)
            images = images.cuda(gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(D_out1,
                                        _domain_label(D_out1, source_label))
            loss_adv_target2 = bce_loss(D_out2,
                                        _domain_label(D_out2, source_label))

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D (bring back requires_grad) ----
            _set_requires_grad((model_D1, model_D2), True)

            d1, d2 = _train_discriminators(pred1, pred2, source_label)
            loss_D_value1 += d1
            loss_D_value2 += d2

            d1, d2 = _train_discriminators(pred_target1, pred_target2,
                                           target_label)
            loss_D_value1 += d1
            loss_D_value2 += d2

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))
        if writer is not None:
            writer.add_scalar('loss/train/segmentation/1', loss_seg_value1,
                              i_iter)
            writer.add_scalar('loss/train/segmentation/2', loss_seg_value2,
                              i_iter)
            writer.add_scalar('loss/train/adversarial/1',
                              loss_adv_target_value1, i_iter)
            writer.add_scalar('loss/train/adversarial/2',
                              loss_adv_target_value2, i_iter)
            writer.add_scalar('loss/train/domain/1', loss_D_value1, i_iter)
            writer.add_scalar('loss/train/domain/2', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            if is_main:
                print('save model ...')
                _save_checkpoint(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            if is_main:
                print('taking snapshot ...')
                _save_checkpoint(i_iter)
        if writer is not None:
            writer.flush()

    if writer is not None:
        writer.close()
def main():
    """Run single-scale inference on Cityscapes and save predicted label maps.

    Options come from ``get_arguments()``. The selected network
    (``DeeplabMulti``, ``Oracle`` -> ``Res_Deeplab``, or ``DeeplabVGG``) is
    restored from ``args.restore_from`` (URL or local path) and run on every
    image of ``args.data_list`` with batch size 1. Its output is upsampled to
    1024x2048, and the per-pixel argmax is written to ``args.save`` twice:
    once as a label image named after the input, and once as a colourised
    ``<stem>_color.png``.

    Raises:
        ValueError: If ``args.model`` is not a supported architecture.
    """
    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    # Tolerate checkpoints from other PyTorch versions: drop keys this model
    # does not have, then load the MERGED dict so that parameters missing from
    # the checkpoint keep their current values instead of failing strict load.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # align_corners only exists (and is needed to silence warnings) from 0.4.
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    # no_grad replaces the deprecated Variable(..., volatile=True).
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, name = batch
            image = image.cuda(gpu0)
            if args.model == 'DeeplabMulti':
                # Two outputs; only the second one is used for prediction.
                _, output = model(image)
            else:  # 'DeeplabVGG' or 'Oracle': single output
                output = model(image)
            output = interp(output).cpu().data[0].numpy()

            # (C, H, W) scores -> (H, W) uint8 label map
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))
            output_col.save('%s/%s_color.png' %
                            (args.save, name.split('.')[0]))
def main():
    """Run multi-scale, flip-augmented inference and save prediction maps.

    The model type and normalisation settings are read from the ``opts.yaml``
    stored next to ``args.restore_from``. Cityscapes images are evaluated at
    scales 1.0 and 1.25, each plain and horizontally flipped. The summed,
    upsampled softmax scores are turned into three per-image outputs, each
    passed to ``save`` / ``save_heatmap`` / ``save_scoremap`` together with a
    path under ``args.save``:

    * a label map (argmax);
    * a heatmap: for DeepLab, the log-scaled KL divergence between the two
      outputs; zeros for single-output models;
    * a ``1 - normalised confidence`` score map.

    Returns:
        The output directory (``args.save`` with the model folder name
        appended).

    Raises:
        ValueError: If the configured model type is not supported.
    """
    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # safe_load: the options file only holds plain values. yaml.load without
        # a Loader is unsafe and raises TypeError on PyYAML >= 6.
        config = yaml.safe_load(stream)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # Key mismatch: presumably a checkpoint saved from a DataParallel
        # wrapper ('module.'-prefixed keys), so wrap the model the same way.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    def _make_loader(scale):
        # Deterministic loader resizing Cityscapes images to scale * 1024x512.
        dataset = cityscapesDataSet(
            args.data_dir,
            args.data_list,
            crop_size=(round(512 * scale), round(1024 * scale)),
            resize_size=(round(1024 * scale), round(512 * scale)),
            mean=IMG_MEAN,
            scale=False,
            mirror=False,
            set=args.set)
        return data.DataLoader(dataset,
                               batch_size=batchsize,
                               shuffle=False,
                               pin_memory=True,
                               num_workers=4)

    testloader = _make_loader(1)
    testloader2 = _make_loader(1.25)
    # A 0.9-scale pass used to be loaded here but its results were never used.

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    log_sm = torch.nn.LogSoftmax(dim=1)
    kl_distance = nn.KLDivLoss(reduction='none')

    def _fused(inputs, flip):
        # Upsampled softmax of (0.5 * first output + second output). With
        # flip=True the input is mirrored and the raw outputs mirrored back.
        # Also returns the raw (mirrored-back) outputs for the heatmap.
        if flip:
            out1, out2 = model(fliplr(inputs))
            out1, out2 = fliplr(out1), fliplr(out2)
        else:
            out1, out2 = model(inputs)
        return interp(sm(0.5 * out1 + out2)), out1, out2

    def _normalised(arr):
        # Scale to a peak of 1; all-zero maps are returned unchanged instead of
        # turning into NaNs through a division by zero.
        peak = np.max(arr)
        return arr / peak if peak > 0 else arr

    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, _, name = batch
        image2, _, _, _ = batch2

        print('\r>>>>Extracting feature...%03d/%03d' %
              (index * batchsize, NUM_STEPS),
              end='')
        heatmap_batch = None
        with torch.no_grad():
            if args.model == 'DeepLab':
                # Inputs go to the model's device, not the default CUDA device.
                inputs = image.cuda(gpu0)
                output_batch, heat1, heat2 = _fused(inputs, flip=False)
                fused, out1, out2 = _fused(inputs, flip=True)
                output_batch += fused
                heat1, heat2 = heat1 + out1, heat2 + out2
                del fused, out1, out2, inputs

                inputs2 = image2.cuda(gpu0)
                output_batch += _fused(inputs2, flip=False)[0]
                output_batch += _fused(inputs2, flip=True)[0]
                del inputs2

                # Per-pixel disagreement between the two outputs.
                heatmap_batch = torch.sum(kl_distance(log_sm(heat1),
                                                      sm(heat2)),
                                          dim=1)
                heatmap_batch = torch.log(
                    1 + 10 * heatmap_batch)  # for visualization
                heatmap_batch = heatmap_batch.cpu().numpy()
            else:  # 'DeeplabVGG' or 'Oracle': single output, single scale
                # Softmax so the score map is a probability, as for DeepLab.
                output_batch = interp(sm(model(image.cuda(gpu0))))
            output_batch = output_batch.cpu().numpy()

        # (N, C, H, W) -> (N, H, W, C); argmax / max over classes.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        scoremap_batch = np.asarray(np.max(output_batch, axis=3))
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)
        if heatmap_batch is None:
            # Single-output models have no second output to compare against.
            heatmap_batch = np.zeros(output_batch.shape, dtype=np.float32)

        output_iterator = []
        heatmap_iterator = []
        scoremap_iterator = []
        paths = []
        for label_map, heat, score, src in zip(output_batch, heatmap_batch,
                                               scoremap_batch, name):
            output_iterator.append(label_map)
            heatmap_iterator.append(_normalised(heat))
            scoremap_iterator.append(1 - _normalised(score))
            paths.append('%s/%s' % (args.save, src.split('/')[-1]))
        with Pool(4) as p:
            p.map(save, zip(output_iterator, paths))
            p.map(save_heatmap, zip(heatmap_iterator, paths))
            p.map(save_scoremap, zip(scoremap_iterator, paths))

        del output_batch

    return args.save
class AD_Trainer(nn.Module):
    """Bundle a two-head segmentation generator, two discriminators and their optimizers.

    ``gen_update`` trains the generator ``self.G`` with segmentation losses on a
    source batch and a target batch; ``dis_update`` trains ``self.D1`` and
    ``self.D2`` on detached predictions.  Only ``args.model == 'DeepLab'`` is
    supported.
    """

    def __init__(self, args):
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        # Running per-class loss weights for the source batch (no suffix) and
        # the target batch (``_t`` suffix); all start at 1.
        self.class_weight = torch.ones(self.num_classes).cuda()
        self.often_weight = torch.ones(self.num_classes).cuda()
        self.class_weight_t = torch.ones(self.num_classes).cuda()
        self.often_weight_t = torch.ones(self.num_classes).cuda()
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label

        if args.model != 'DeepLab':
            # Previously this fell through and crashed later on ``self.G``.
            raise ValueError(
                "AD_Trainer only supports args.model == 'DeepLab', got %r" % (args.model,))
        self.G = DeeplabMulti(num_classes=args.num_classes, use_se=args.use_se,
                              train_bn=args.train_bn, norm_style=args.norm_style,
                              droprate=args.droprate)
        self.G.load_state_dict(self._merge_pretrained(args.restore_from))

        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate, momentum=args.momentum,
                                 nesterov=True, weight_decay=args.weight_decay)
        self.dis1_opt = optim.Adam(self.D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
        self.dis2_opt = optim.Adam(self.D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))

        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        self.G = self.G.cuda()
        self.D1 = self.D1.cuda()
        self.D2 = self.D2.cuda()
        self.interp = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size, mode='bilinear', align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.ones(self.num_classes).cuda()
        if args.fp16:
            # Wrap models and optimizers with apex AMP (mixed precision, O1).
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1")
            self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1")
            self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1")

    def _merge_pretrained(self, restore_from):
        """Return ``self.G``'s state dict overwritten by weights from ``restore_from``.

        ``restore_from`` values starting with ``http`` are downloaded.  For those,
        the first key component is dropped (example key: ``Scale.layer5...``),
        and ``fc`` / ``layer5`` entries are skipped.  Local checkpoints are copied
        key-for-key, with a leading ``module.`` component stripped when present
        (presumably from a DataParallel-wrapped save).
        """
        from_url = restore_from[:4] == 'http'
        if from_url:
            saved_state_dict = model_zoo.load_url(restore_from)
        else:
            saved_state_dict = torch.load(restore_from)

        new_params = self.G.state_dict().copy()
        for key in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            i_parts = key.split('.')
            if from_url:
                if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[key]
                    print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
            elif i_parts[0] == 'module':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[key]
                print('%s is loaded from pre-trained weight.\n' % i_parts[1:])
            else:
                new_params[key] = saved_state_dict[key]
                print('%s is loaded from pre-trained weight.\n' % i_parts)
        return new_params

    def _balanced_weights(self, labels, often_weight):
        """Compute updated ``(often_weight, class_weight)`` from a label batch.

        A class whose pixel count is below ``64*64`` per image gets weight
        ``self.max_value``.  When ``self.often_balance`` is set, classes absent
        from the batch also push the running ``often_weight`` toward
        ``self.max_value`` (EMA with momentum 0.9).

        Args:
            labels: integer label tensor whose first dim is the batch size.
            often_weight: current running frequency weights (num_classes,).
        """
        device = often_weight.device
        n = labels.shape[0]
        count = torch.zeros(self.num_classes, device=device)
        for c in range(self.num_classes):
            count[c] = torch.sum(labels == c)
        weight = torch.ones(self.num_classes, device=device)
        # Small objects: original train size is 512*256.
        weight[count < 64 * 64 * n] = self.max_value
        often = torch.ones(self.num_classes, device=device)
        if self.often_balance:
            often[count == 0] = self.max_value
        often_weight = 0.9 * often_weight + 0.1 * often
        return often_weight, weight * often_weight

    def update_class_criterion(self, labels):
        """Update source-side class weights from ``labels``; return a weighted CE loss."""
        self.often_weight, self.class_weight = self._balanced_weights(labels, self.often_weight)
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255)

    def update_class_criterion_t(self, labels):
        """Update target-side class weights from ``labels``; return a weighted CE loss."""
        self.often_weight_t, self.class_weight_t = self._balanced_weights(labels, self.often_weight_t)
        print(self.class_weight_t)
        return nn.CrossEntropyLoss(weight=self.class_weight_t, ignore_index=255)

    def update_label(self, labels, prediction):
        """Set easy pixels to 255 (ignore) and return ``labels``.

        Pixels whose weighted CE loss is below the ``self.only_hard_label``
        percentile are masked out.  ``labels`` is modified in place.
        """
        criterion = nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255, reduction='none')
        loss = criterion(prediction, labels)
        print('original loss: %f' % self.seg_loss(prediction, labels))
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
        labels[loss < mm] = 255
        return labels

    def _variance_weighted_loss(self, labels, pred1, pred2, class_weight):
        """Return a CE loss on ``pred1`` down-weighted by its KL disagreement with ``pred2``.

        Per-pixel CE is multiplied by ``exp(-KL)``.  The mean KL is then added so
        the network cannot trivially inflate the disagreement.
        """
        criterion = nn.CrossEntropyLoss(weight=class_weight, ignore_index=255, reduction='none')
        kl_distance = nn.KLDivLoss(reduction='none')
        loss = criterion(pred1, labels)

        variance = torch.sum(kl_distance(self.log_sm(pred1), self.sm(pred2)), dim=1)
        exp_variance = torch.exp(-variance)

        print(variance.shape)
        print('variance mean: %.4f' % torch.mean(exp_variance))
        print('variance min: %.4f' % torch.min(exp_variance))
        print('variance max: %.4f' % torch.max(exp_variance))
        return torch.mean(loss * exp_variance) + torch.mean(variance)

    def update_variance(self, labels, pred1, pred2):
        """Variance-weighted CE using the source-side class weights."""
        return self._variance_weighted_loss(labels, pred1, pred2, self.class_weight)

    def update_variance_t(self, labels, pred1, pred2):
        """Variance-weighted CE using the target-side class weights."""
        return self._variance_weighted_loss(labels, pred1, pred2, self.class_weight_t)

    def update_loss(self, loss):
        """Backpropagate ``loss``, scaling it through apex AMP when fp16 is on."""
        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        """Run one generator step on a source batch and a target batch.

        Returns a 12-tuple:
        ``(loss_seg1, loss_seg2, loss_seg1_t, loss_seg2_t, 0, 0, 0, 0, pred1, pred2, None, None)``.
        The four zero entries are placeholder loss tensors.
        """
        self.gen_opt.zero_grad()

        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)

        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)

        loss_seg1 = self.seg_loss(pred1, labels)
        loss_seg2 = self.seg_loss(pred2, labels)
        loss = loss_seg2 + self.lambda_seg * loss_seg1
        self.update_loss(loss)

        images_t = images_t.cuda()
        labels_t = labels_t.long().cuda()
        pred1_t, pred2_t = self.G(images_t)
        pred1_t = self.interp(pred1_t)
        pred2_t = self.interp(pred2_t)

        if self.class_balance:
            # Called for its side effect of refreshing self.class_weight_t,
            # which update_variance_t reads.
            self.seg_loss_t = self.update_class_criterion_t(labels_t)

        # Seg loss on target weighted by the KL disagreement of the two heads.
        loss_seg1_t = self.update_variance_t(labels_t, pred1_t, pred2_t)
        loss_seg2_t = self.update_variance_t(labels_t, pred2_t, pred1_t)
        loss = loss_seg2_t + self.lambda_seg * loss_seg1_t
        self.update_loss(loss)

        self.gen_opt.step()
        zero_loss = torch.zeros(1).cuda()
        return loss_seg1, loss_seg2, loss_seg1_t, loss_seg2_t, zero_loss, zero_loss, zero_loss, zero_loss, pred1, pred2, None, None

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        """Update D1/D2 on detached predictions; return ``(loss_D1, loss_D2)``.

        The real input is ``softmax(0.5*pred1 + pred2)``.  The fake inputs are
        the softmaxed target predictions.
        """
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()
        pred1 = pred1.detach()
        pred2 = pred2.detach()
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        # With multi_gpu the discriminators are expected to be wrapped (``.module``).
        d1 = self.D1.module if self.multi_gpu else self.D1
        d2 = self.D2.module if self.multi_gpu else self.D2
        loss_D1, reg1 = d1.calc_dis_loss(self.D1, input_fake=F.softmax(pred_target1, dim=1),
                                         input_real=F.softmax(0.5 * pred1 + pred2, dim=1))
        loss_D2, reg2 = d2.calc_dis_loss(self.D2, input_fake=F.softmax(pred_target2, dim=1),
                                         input_real=F.softmax(0.5 * pred1 + pred2, dim=1))

        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss, [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2
def main():
    """Create the model and start the training.

    Reads the module-level ``args``.  It trains a ``DeeplabMulti`` generator
    together with two discriminators:
      * ``FCDiscriminator`` D1 works on channel-concatenated pairs of softmax
        maps and is trained with an MSE loss against source/target/mix labels.
      * ``FCDiscriminator`` D2 works on single softmax maps and is trained with
        a BCE loss.
    Snapshots are written to ``args.snapshot_dir``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """

    gpu_id_2 = 1  # device of the generator
    gpu_id_1 = 0  # device of the discriminators and all losses

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model != 'DeepLab':
        raise ValueError("Unsupported model %r; only 'DeepLab' is implemented" % (args.model,))
    model = DeeplabMulti(num_classes=args.num_classes)
    if args.restore_from[:4] == 'http':
        print("from url")
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        print("from restore")
        # Honour --restore-from (previously a hard-coded personal path).
        # Tensors saved on cuda:3 are remapped onto cuda:0.
        saved_state_dict = torch.load(args.restore_from,
                                      map_location={"cuda:3": "cuda:0"})

    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        # e.g. Scale.layer5.conv2d_list.3.weight
        i_parts = key.split('.')
        # Skip the classifier head (layer5) and fc weights.
        if i_parts[1] != 'layer5' and i_parts[0] != 'fc':
            new_params[key] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model.train()
    model.cuda(gpu_id_2)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes, match=False)

    model_D1.train()
    model_D1.cuda(gpu_id_1)

    model_D2.train()
    model_D2.cuda(gpu_id_1)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(SynthiaDataSet(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size,
        scale=args.random_scale,
        mirror=args.random_mirror,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    _, batch_last = next(trainloader_iter)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        cut_size=cut_size,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    _, batch_last_target = next(targetloader_iter)

    # model.optim_parameters(args) handles different models' lr settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    mse_loss = torch.nn.MSELoss()

    def upsample_(input_):
        return nn.functional.interpolate(input_,
                                         size=(input_size[1], input_size[0]),
                                         mode='bilinear',
                                         align_corners=False)

    def upsample_target(input_):
        return nn.functional.interpolate(input_,
                                         size=(input_size_target[1],
                                               input_size_target[0]),
                                         mode='bilinear',
                                         align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    # labels for adversarial training
    source_label = 1
    target_label = -1
    mix_label = 0

    def result_model(batch, interp_):
        """Run G on ``batch``; return both upsampled predictions and labels on gpu_id_1."""
        images, labels, _, _ = batch
        images = images.cuda(gpu_id_2)
        labels = labels.long().cuda(gpu_id_1)
        pred1, pred2 = model(images)
        pred1 = interp_(pred1)
        pred2 = interp_(pred2)
        return pred1.cuda(gpu_id_1), pred2.cuda(gpu_id_1), labels

    def mse_to(d_out, value):
        """MSE between discriminator output and a constant target ``value``."""
        return mse_loss(d_out, torch.full_like(d_out, value))

    def bce_to(d_out, value):
        """BCE-with-logits between discriminator output and a constant target ``value``."""
        return bce_loss(d_out, torch.full_like(d_out, value))

    def set_requires_grad(nets, flag):
        """Enable or disable gradients for every parameter of ``nets``."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad = flag

    def save_snapshot(tag):
        """Save G, D1 and D2 state dicts as GTA5_<tag>[ _D1|_D2].pth."""
        torch.save(model.state_dict(),
                   osp.join(args.snapshot_dir, 'GTA5_' + tag + '.pth'))
        torch.save(model_D1.state_dict(),
                   osp.join(args.snapshot_dir, 'GTA5_' + tag + '_D1.pth'))
        torch.save(model_D2.state_dict(),
                   osp.join(args.snapshot_dir, 'GTA5_' + tag + '_D2.pth'))

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        number1 = 0
        number2 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _sub_i in range(args.iter_size):

            # ---------------- train G ----------------
            # don't accumulate grads in D
            set_requires_grad((model_D1, model_D2), False)

            beta = args.beta
            if i_iter == 0:
                print(beta)
            _, batch = next(trainloader_iter)
            _, batch_target = next(targetloader_iter)

            # Segmentation loss on the source batch (not included in the
            # reported loss_seg values, which are target-only).
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1, labels = loss_calc(pred1, labels, gpu_id_1, beta)
            number1 = torch.sum(labels == 255).item()
            loss_seg2, _ = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()

            # Segmentation loss on the target batch.
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1, labels = loss_calc(pred1, labels, gpu_id_1, beta)
            number2 = torch.sum(labels == 255).item()
            loss_seg2, _ = loss_calc(pred2, labels, gpu_id_1, beta)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # D1 adversarial: (current target, previous target) pairs -> source_label.
            pred1_last_target, pred2_last_target, _ = result_model(
                batch_last_target, interp_target)
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)

            fake1_D = torch.cat((F.softmax(pred1_target, dim=1),
                                 F.softmax(pred1_last_target, dim=1)), dim=1)
            fake2_D = torch.cat((F.softmax(pred2_target, dim=1),
                                 F.softmax(pred2_last_target, dim=1)), dim=1)
            D_out_fake_1 = model_D1(fake1_D)
            D_out_fake_2 = model_D1(fake2_D)

            loss_adv_target1 = mse_to(D_out_fake_1, source_label)
            loss_adv_target2 = mse_to(D_out_fake_2, source_label)
            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2) / args.iter_size
            loss.backward()

            # D1 adversarial: (target, source) pairs -> source_label, weighted x2.
            pred1, pred2, _ = result_model(batch, interp)
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)

            mix1_D = torch.cat((F.softmax(pred1_target, dim=1),
                                F.softmax(pred1, dim=1)), dim=1)
            mix2_D = torch.cat((F.softmax(pred2_target, dim=1),
                                F.softmax(pred2, dim=1)), dim=1)
            D_out_mix_1 = model_D1(mix1_D)
            D_out_mix_2 = model_D1(mix2_D)

            loss_adv_target1 = mse_to(D_out_mix_1, source_label) * 2
            loss_adv_target2 = mse_to(D_out_mix_2, source_label) * 2
            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2) / args.iter_size
            loss.backward()
            # *_value1 tracks D1-related losses, *_value2 tracks D2-related ones.
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy() / args.iter_size
            loss_adv_target_value1 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # D2 adversarial: single target maps -> source_label.
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)
            D_out_target_1 = model_D2(F.softmax(pred1_target, dim=1))
            D_out_target_2 = model_D2(F.softmax(pred2_target, dim=1))

            loss_adv_target1 = bce_to(D_out_target_1, source_label)
            loss_adv_target2 = bce_to(D_out_target_2, source_label)
            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2) / args.iter_size
            loss.backward()
            loss_adv_target_value2 += loss_adv_target1.data.cpu().numpy() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

            # ---------------- train D ----------------
            # bring back requires_grad
            set_requires_grad((model_D1, model_D2), True)

            pred1_last, pred2_last, _ = result_model(batch_last, interp)

            # Detach so D updates do not backprop into G.
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            pred1_target = pred1_target.detach()
            pred2_target = pred2_target.detach()
            pred1_last = pred1_last.detach()
            pred2_last = pred2_last.detach()
            pred1_last_target = pred1_last_target.detach()
            pred2_last_target = pred2_last_target.detach()

            pred1_D = F.softmax(pred1, dim=1)
            pred2_D = F.softmax(pred2, dim=1)
            pred1_last_D = F.softmax(pred1_last, dim=1)
            pred2_last_D = F.softmax(pred2_last, dim=1)
            pred1_target_D = F.softmax(pred1_target, dim=1)
            pred2_target_D = F.softmax(pred2_target, dim=1)
            pred1_last_target_D = F.softmax(pred1_last_target, dim=1)
            pred2_last_target_D = F.softmax(pred2_last_target, dim=1)

            # D1 with source: (source, prev source) -> source_label,
            # (prev source, target) -> mix_label.
            D_out1_real = model_D1(torch.cat((pred1_D, pred1_last_D), dim=1))
            D_out2_real = model_D1(torch.cat((pred2_D, pred2_last_D), dim=1))
            D_out1_mix = model_D1(torch.cat((pred1_last_D, pred1_target_D), dim=1))
            D_out2_mix = model_D1(torch.cat((pred2_last_D, pred2_target_D), dim=1))

            loss_D1 = (mse_to(D_out1_real, source_label) +
                       mse_to(D_out1_mix, mix_label)) / args.iter_size / 2
            loss_D2 = (mse_to(D_out2_real, source_label) +
                       mse_to(D_out2_mix, mix_label)) / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value1 += loss_D2.data.cpu().numpy()

            # D1 with target: (prev target, target) -> target_label,
            # (source, prev target) -> mix_label.
            D_out1 = model_D1(torch.cat((pred1_last_target_D, pred1_target_D), dim=1))
            D_out2 = model_D1(torch.cat((pred2_last_target_D, pred2_target_D), dim=1))
            D_out3 = model_D1(torch.cat((pred1_D, pred1_last_target_D), dim=1))
            D_out4 = model_D1(torch.cat((pred2_D, pred2_last_target_D), dim=1))

            loss_D1 = (mse_to(D_out1, target_label) +
                       mse_to(D_out3, mix_label)) / args.iter_size / 2
            loss_D2 = (mse_to(D_out2, target_label) +
                       mse_to(D_out4, mix_label)) / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value1 += loss_D2.data.cpu().numpy()

            # D2: source maps -> source_label, target maps -> mix_label (0),
            # since BCE targets must lie in [0, 1].
            D_out1 = model_D2(pred1_D)
            D_out2 = model_D2(pred2_D)
            D_out3 = model_D2(pred1_target_D)
            D_out4 = model_D2(pred2_target_D)

            loss_D1 = (bce_to(D_out1, source_label) +
                       bce_to(D_out3, mix_label)) / args.iter_size / 2
            loss_D2 = (bce_to(D_out2, source_label) +
                       bce_to(D_out4, mix_label)) / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value2 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # The current batches become the "previous" batches next time.
            batch_last, batch_last_target = batch, batch_target

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, number1 = {8}, number2 = {9}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2, number1, number2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_snapshot(str(args.num_steps_stop))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            save_snapshot(str(i_iter))
def main():
    """Create the model and start the evaluation process.

    The model type and normalisation style are read from the ``opts.yaml`` file
    stored next to ``args.restore_from``. For every image in ``args.data_list``
    a train-id label PNG and a ``*_color.png`` visualisation are written to
    ``args.save/<parent directory of the image>/``. For the ``test`` and ``val``
    splits the submission layout is filled as well:

    * ``submit/labelTrainIds`` holds the predicted train ids.
    * ``submit/labelTrainIds_invalid`` holds the same ids, with pixels whose top
      class probability is below ``threshold`` set to 255.
    * ``submit/confidence`` holds the top class probability, scaled to uint16.

    For the ``DeepLab`` model the class probabilities are averaged over four
    passes: the loader resolution and 1.25x, each with and without a horizontal
    flip. ``DeeplabVGG`` and ``Oracle`` use a single pass.

    Returns:
        The output directory ``args.save``.

    Raises:
        ValueError: if the configured model type is not supported.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # PyYAML >= 6 requires an explicit Loader. FullLoader is what a bare
        # ``yaml.load(stream)`` used in PyYAML 5.x. Only load trusted files.
        config = yaml.load(stream, Loader=yaml.FullLoader)

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    confidence_path = os.path.join(args.save, 'submit/confidence')
    label_path = os.path.join(args.save, 'submit/labelTrainIds')
    label_invalid_path = os.path.join(args.save,
                                      'submit/labelTrainIds_invalid')
    for path in [args.save, confidence_path, label_path, label_invalid_path]:
        os.makedirs(path, exist_ok=True)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail loudly instead of hitting an unbound `model` further down.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # Key mismatch. Presumably the checkpoint was saved from a
        # DataParallel-wrapped model ("module."-prefixed keys), so retry
        # through the same wrapper.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(DarkZurichDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(h, w),
                                                   resize_size=(w, h),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second, upscaled view of the same list, used for multi-scale averaging.
    scale = 1.25
    testloader2 = data.DataLoader(DarkZurichDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(h * scale), round(w * scale)),
        resize_size=(round(w * scale), round(h * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(1080, 1920),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(1080, 1920), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)

    def _fused_prob(output1, output2):
        # Combine the two DeeplabMulti outputs (output1 weighted by 0.5), turn
        # the result into per-class probabilities and upsample to 1080x1920.
        return interp(sm(0.5 * output1 + output2))

    # Pixels whose highest class probability is below this value are written
    # as 255 to submit/labelTrainIds_invalid.
    threshold = 0.3274

    with torch.no_grad():
        for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
            image, _, name = batch
            image2, _, _ = batch2

            # Inputs must live on the same device as the model.
            inputs = image.cuda(gpu0)
            print('\r>>>>Extracting feature...%04d/%04d' %
                  (index * batchsize, args.batchsize * len(testloader)),
                  end='')
            if args.model == 'DeepLab':
                inputs2 = image2.cuda(gpu0)

                output1, output2 = model(inputs)
                output_batch = _fused_prob(output1, output2)
                output1, output2 = model(fliplr(inputs))
                output_batch += _fused_prob(fliplr(output1), fliplr(output2))

                output1, output2 = model(inputs2)
                output_batch += _fused_prob(output1, output2)
                output1, output2 = model(fliplr(inputs2))
                output_batch += _fused_prob(fliplr(output1), fliplr(output2))
                del output1, output2, inputs, inputs2

                # Average of the four passes.
                output_batch = (output_batch.cpu() / 4).numpy()
            else:
                # DeeplabVGG / Oracle: single pass. Softmax is applied after
                # upsampling, so the per-pixel argmax is unchanged. It makes
                # score_batch a probability, which the threshold and the
                # uint16 confidence encoding below require.
                output_batch = sm(interp(model(inputs))).cpu().numpy()

            output_batch = output_batch.transpose(0, 2, 3, 1)
            score_batch = np.max(output_batch, axis=3)
            output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                      dtype=np.uint8)

            for i in range(output_batch.shape[0]):
                output_single = output_batch[i, :, :]
                output_col = colorize_mask(output_single)
                output = Image.fromarray(output_single)

                name_tmp = name[i].split('/')[-1]
                dir_name = name[i].split('/')[-2]
                save_path = args.save + '/' + dir_name
                os.makedirs(save_path, exist_ok=True)
                output.save('%s/%s' % (save_path, name_tmp))
                print('%s/%s' % (save_path, name_tmp))
                output_col.save('%s/%s_color.png' %
                                (save_path, name_tmp.split('.')[0]))

                if args.set == 'test' or args.set == 'val':
                    # label
                    output.save('%s/%s' % (label_path, name_tmp))
                    # label invalid: mask out low-confidence pixels
                    output_single[score_batch[i, :, :] < threshold] = 255
                    output = Image.fromarray(output_single)
                    output.save('%s/%s' % (label_invalid_path, name_tmp))
                    # confidence, scaled to the full uint16 range
                    confidence = score_batch[i, :, :] * 65535
                    confidence = np.asarray(confidence, dtype=np.uint16)
                    print(confidence.min(), confidence.max())
                    iio.imwrite('%s/%s' % (confidence_path, name_tmp),
                                confidence)

    return args.save
def main():
    """Create the model and start the training.

    Trains a DeeplabMulti segmentation network on GTA5DataSet (source) and
    cityscapesDataSet (target) batches. Each domain's segmentation loss gets an
    IRM-style penalty, weighted by ``args.lambda_irm``: the squared gradient of
    that loss with respect to a dummy scalar (value 1.0) passed to
    ``model.forward_irm``.

    Reads the module-level ``args``. The model weights are saved to
    ``args.snapshot_dir`` every ``args.save_pred_every`` iterations and when
    ``args.num_steps_stop`` is reached.

    Raises:
        ValueError: if ``args.model`` is not supported.
    """
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Keys look like "Scale.layer5.conv2d_list.3.weight". Drop the
            # leading component. When num_classes == 19, keys under layer5
            # are not copied.
            i_parts = i.split('.')
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        model.load_state_dict(new_params)
    else:
        # Fail loudly instead of hitting an unbound `model` below.
        raise ValueError('Unsupported model type: %s' % args.model)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    def _forward_single(_iter, _interp):
        """Run one batch from ``_iter``, backprop seg loss + IRM penalty.

        ``_interp`` upsamples predictions to the resolution of that loader's
        labels. Returns the two (unnormalised) segmentation losses.
        """
        # Dummy multiplier for the IRM penalty. It must be on the model's device.
        _scalar = torch.tensor(1.).cuda(args.gpu).requires_grad_()

        _, batch = next(_iter)
        # NOTE(review): this 4-tuple unpack (with labels) also runs on the
        # target loader, so cityscapesDataSet must yield labels here — confirm.
        images, labels, _, _ = batch
        images = Variable(images).cuda(args.gpu)

        pred1, pred2 = model.forward_irm(images, _scalar)
        pred1 = _interp(pred1)
        pred2 = _interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        # proper normalization for gradient accumulation over iter_size
        loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size

        # compute penalty: squared gradient of the loss w.r.t. the dummy scalar
        grad = autograd.grad(loss, [_scalar], create_graph=True)[0]
        penalty = torch.sum(grad ** 2)

        # Out-of-place add: `loss` is already part of the autograd graph.
        loss = loss + args.lambda_irm * penalty
        loss.backward()

        return loss_seg1, loss_seg2

    for i_iter in range(args.num_steps):

        loss_seg1_src = 0
        loss_seg2_src = 0

        loss_seg1_tgt = 0
        loss_seg2_tgt = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            # train with source
            loss_seg1, loss_seg2 = _forward_single(trainloader_iter, interp)
            # .item() works on 0-dim loss tensors; `.numpy()[0]` does not.
            loss_seg1_src += loss_seg1.item() / args.iter_size
            loss_seg2_src += loss_seg2.item() / args.iter_size

            # train with target (upsampled to the target input size)
            loss_seg1, loss_seg2 = _forward_single(targetloader_iter,
                                                   interp_target)
            loss_seg1_tgt += loss_seg1.item() / args.iter_size
            loss_seg2_tgt += loss_seg2.item() / args.iter_size

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1_src = {2:.3f} loss_seg2_src = {3:.3f}, loss_seg1_tgt = {4:.3f} loss_seg2_tgt = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg1_src, loss_seg2_src,
                    loss_seg1_tgt, loss_seg2_tgt))

        # Only the segmentation model exists in this script (no discriminators).
        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
def main():
    """Create the model and start the evaluation process.

    Runs single-pass inference over ``args.data_list`` (GTA5DataSet, resized
    to 1280x640). For each image it writes a uint8 train-id PNG and a
    ``*_color.png`` visualisation into ``args.save`` with the name of the
    checkpoint's parent directory appended.

    Returns:
        The output directory.

    Raises:
        ValueError: if ``args.model`` is not a supported model type.
    """
    args = get_arguments()

    batchsize = args.batchsize

    # NOTE(review): this is plain string concatenation, so args.save needs a
    # trailing separator for model_name to become a sub-directory — confirm.
    model_name = os.path.basename(os.path.dirname(args.restore_from))
    args.save += model_name
    os.makedirs(args.save, exist_ok=True)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes,
                             train_bn=False,
                             norm_style='in')
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        # Fail loudly instead of hitting an unbound `model` further down.
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # Key mismatch. Presumably the checkpoint was saved from a
        # DataParallel-wrapped model ("module."-prefixed keys), so retry
        # through the same wrapper.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda()

    testloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                             args.data_list,
                                             crop_size=(640, 1280),
                                             resize_size=(1280, 640),
                                             mean=IMG_MEAN,
                                             scale=False,
                                             mirror=False),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(640, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(640, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)

    # Inference only: don't build autograd graphs for every batch.
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if (index * batchsize) % 100 == 0:
                print('%d processd' % (index * batchsize))
            image, _, _, name = batch
            print(image.shape)

            inputs = Variable(image).cuda()
            if args.model == 'DeeplabMulti':
                output1, output2 = model(inputs)
                # output1 is weighted by 0.5 before the softmax.
                output_batch = interp(sm(0.5 * output1 +
                                         output2)).cpu().numpy()
            else:  # 'DeeplabVGG' or 'Oracle': single output, argmax only
                output_batch = interp(model(inputs)).cpu().numpy()

            output_batch = output_batch.transpose(0, 2, 3, 1)
            output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                      dtype=np.uint8)

            for i in range(output_batch.shape[0]):
                output = output_batch[i, :, :]
                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name_tmp = name[i].split('/')[-1]
                output.save('%s/%s' % (args.save, name_tmp))
                output_col.save('%s/%s_color.png' %
                                (args.save, name_tmp.split('.')[0]))

    return args.save
Exemple #13
0
def main():
    """Create the model and start the training.

    Trains a DeeplabMulti segmentation network on the labelled source city
    (``config.dataset``). Two FCDiscriminator networks, one per segmentation
    output, are trained to tell source predictions (label ``source_label``)
    from target-city predictions (``config.target``, label ``target_label``).
    The segmentation network is pushed to make its target predictions look
    like source ones.

    Reads the module-level ``args`` and ``config``. The generator and both
    discriminators are saved to ``args.snapshot_dir`` every
    ``args.save_pred_every`` iterations and when ``args.num_steps_stop`` is
    reached.

    Raises:
        ValueError: if ``args.model`` or ``args.gan`` is not supported.
    """
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network. No checkpoint is restored in this variant.
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported model type: %s' % args.model)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D: one discriminator per segmentation output
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    for model_D in (model_D1, model_D2):
        model_D.train()
        model_D.cuda(args.gpu)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    train_set = spacenet.Spacenet(city=config.dataset, split='train', img_root=config.img_root)
    trainloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.num_workers, drop_last=True)
    trainloader_iter = enumerate(trainloader)

    target_set = spacenet.Spacenet(city=config.target, split='train', img_root=config.img_root)
    targetloader = DataLoader(target_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=True)
    targetloader_iter = enumerate(targetloader)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        # Fail loudly instead of hitting an unbound `bce_loss` in the loop.
        raise ValueError('Unsupported GAN objective: %s' % args.gan)

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    def _next_batch(loader_iter, loader):
        """Return (iterator, batch), restarting ``loader`` when exhausted."""
        try:
            _, batch = next(loader_iter)
        except StopIteration:
            loader_iter = enumerate(loader)
            _, batch = next(loader_iter)
        return loader_iter, batch

    def _set_requires_grad(net, flag):
        """Enable/disable gradient accumulation for all parameters of net."""
        for param in net.parameters():
            param.requires_grad = flag

    def _adv_loss(d_out, label):
        """Loss of discriminator output ``d_out`` against a constant label map."""
        # full_like keeps d_out's dtype and device (args.gpu).
        return bce_loss(d_out, torch.full_like(d_out, label))

    def _save_snapshot(tag):
        """Save G, D1 and D2 weights as paris_<tag>.pth / _D1.pth / _D2.pth."""
        torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'paris_' + tag + '.pth'))
        torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'paris_' + tag + '_D1.pth'))
        torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'paris_' + tag + '_D2.pth'))

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            _set_requires_grad(model_D1, False)
            _set_requires_grad(model_D2, False)

            # train with source: supervised segmentation loss
            trainloader_iter, batch = _next_batch(trainloader_iter, trainloader)
            images, labels = batch
            images = images.cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            # proper normalization for gradient accumulation over iter_size
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # train with target: make D label target predictions as source
            targetloader_iter, batch = _next_batch(targetloader_iter, targetloader)
            images, _ = batch
            images = images.cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            # Softmax over the class channel of the (N, C, H, W) prediction.
            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = _adv_loss(D_out1, source_label)
            loss_adv_target2 = _adv_loss(D_out2, source_label)

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2) / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            _set_requires_grad(model_D1, True)
            _set_requires_grad(model_D2, True)

            # Source predictions are labelled source_label, target predictions
            # target_label. Detach so no gradient reaches the segmentation model.
            for out1, out2, domain_label in ((pred1, pred2, source_label),
                                             (pred_target1, pred_target2, target_label)):
                D_out1 = model_D1(F.softmax(out1.detach(), dim=1))
                D_out2 = model_D2(F.softmax(out2.detach(), dim=1))

                loss_D1 = _adv_loss(D_out1, domain_label) / args.iter_size / 2
                loss_D2 = _adv_loss(D_out2, domain_label) / args.iter_size / 2

                loss_D1.backward()
                loss_D2.backward()

                loss_D_value1 += loss_D1.item()
                loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            _save_snapshot(str(args.num_steps_stop))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            _save_snapshot(str(i_iter))
Exemple #14
0
def main():
    """Create the model and start the training.

    Skeleton of a GTA5 -> Cityscapes adversarial training loop with a
    two-head ``DeeplabMulti`` segmenter and two discriminators.  The
    instruction strings below mark the parts still to be written
    (discriminator construction, optimizers, and STEP1-STEP4 of the inner
    loop).  Until they are filled in, ``model_D1``, ``model_D2``,
    ``optimizer``, ``optimizer_D1`` and ``optimizer_D2`` are undefined and the
    loop raises ``NameError`` on its first iteration.

    All settings are read from the module-level ``args``.

    Raises:
        ValueError: if ``args.model`` or ``args.gan`` is not a supported value.
    """

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Copy pretrained weights after stripping the leading scope of each
        # key (e.g. "Scale." in "Scale.layer5.conv2d_list.3.weight").  When
        # training for 19 classes the "layer5" parameters are not copied and
        # keep the freshly constructed model's values.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params)
    else:
        raise ValueError('Unsupported model: {}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
                    crop_size=input_size,
                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
                                                     max_iters=args.num_steps * args.iter_size * args.batch_size,
                                                     crop_size=input_size_target,
                                                     scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
                                                     set=args.set),
                                   batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    """
    initialize the discriminator
    model_D1 = 
    model_D2 = 
    """

    """
    set the optimizer
    please refer to the main paper section 5 (Network Training)
    use CLASS METHOD optim_parameters to get the parameters of segmentation model (i.e., model.optim_parameters)
    """

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        # Least-squares GAN objective: regress the discriminator outputs onto
        # the 0/1 domain labels instead of a sigmoid cross-entropy.
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError('Unsupported gan type: {}'.format(args.gan))

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G
            """STEP1: train the segmentation model with source GTA5 data
            You should freeze the discriminator during training G
            """

            """STEP2: learn target to source alignment with target Cityscape data
            """

            # train D
            """STEP3: train the discriminator with source
            You should unfreeze the discriminator
            """

            """STEP4: train the discriminator with target
            """

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
        'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value1, loss_seg_value2, loss_adv_target_value1, loss_adv_target_value2, loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            # Python 3 print(): the former print statements made this file a
            # SyntaxError under the Python 3 syntax used elsewhere in it.
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(model_D1.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
Exemple #15
0
def main():
    """Evaluate every saved GTA5 -> Cityscapes snapshot and dump predictions.

    For each step ``2000, 4000, ..., 120000`` this loads
    ``./snapshots/GTA2Cityscapes/GTA5_<step>.pth`` (segmentation network) and
    ``./snapshots/GTA2Cityscapes/GTA5_<step>_D.pth`` (its two discriminators).
    It then runs every image of the ``args.set`` split through
    ``model(image, model_D, 'target')``.  For each image it writes a label-id
    PNG and a ``*_color.png`` into ``./result/GTA2Cityscapes_<step>``.
    """
    # Options, the data loader and the upsampler are identical for every
    # snapshot, so build them once instead of once per checkpoint.
    args = get_arguments()
    gpu0 = args.gpu

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    # Input channels for model_D[0] (FCDiscriminator) and
    # model_D[1] (OutspaceDiscriminator).
    num_class_list = [2048, 19]

    for snapshot_idx in range(1, 61):
        step = snapshot_idx * 2000
        model_path = './snapshots/GTA2Cityscapes/GTA5_{0:d}.pth'.format(step)
        model_D_path = './snapshots/GTA2Cityscapes/GTA5_{0:d}_D.pth'.format(
            step)
        save_path = './result/GTA2Cityscapes_{0:d}'.format(step)

        if not os.path.exists(save_path):
            os.makedirs(save_path)

        model = DeeplabMulti(num_classes=args.num_classes)
        model.load_state_dict(torch.load(model_path))
        model.eval()
        model.cuda(gpu0)

        # A distinct comprehension variable avoids shadowing the outer
        # snapshot index.
        model_D = nn.ModuleList([
            FCDiscriminator(num_classes=num_class_list[d])
            if d < 1 else OutspaceDiscriminator(num_classes=num_class_list[d])
            for d in range(2)
        ])
        model_D.load_state_dict(torch.load(model_D_path))
        model_D.eval()
        model_D.cuda(gpu0)

        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processd' % index)
                image, _, name = batch
                feat, pred = model(
                    Variable(image).cuda(gpu0), model_D, 'target')

                # (C, H, W) scores -> (H, W) arg-max label ids.
                output = interp(pred).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                file_name = name[0].split('/')[-1]
                output.save('%s/%s' % (save_path, file_name))

                output_col.save('%s/%s_color.png' %
                                (save_path, file_name.split('.')[0]))

        print(save_path)
def main():
    """Create the model and start the evaluation process.

    Loads the segmentation model whose training options are stored in
    ``opts.yaml`` next to ``args.restore_from``.  It predicts every image of
    the robot dataset and writes a label PNG and a ``*_color.png`` per image
    under ``args.save/<parent directory of the image>``.

    For the ``DeepLab`` model the prediction is a test-time ensemble.  The
    softmax of ``0.5 * output1 + output2`` is summed over four inputs: the
    original scale, a 1.25x scale, and the horizontally flipped copy of each.

    Returns:
        The output root directory ``args.save``.

    Raises:
        ValueError: if the configured model type is not supported.
    """

    args = get_arguments()

    config_path = os.path.join(os.path.dirname(args.restore_from), 'opts.yaml')
    with open(config_path, 'r') as stream:
        # yaml.load without an explicit Loader is unsafe and raises TypeError
        # on PyYAML >= 6.  FullLoader (when available) still accepts
        # Python-specific tags such as !!python/tuple in the options file.
        config = yaml.load(stream, Loader=getattr(yaml, 'FullLoader', yaml.Loader))

    args.model = config['model']
    print('ModelType:%s' % args.model)
    print('NormType:%s' % config['norm_style'])
    gpu0 = args.gpu
    batchsize = args.batchsize

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes,
                             use_se=config['use_se'],
                             train_bn=False,
                             norm_style=config['norm_style'])
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    else:
        raise ValueError('Unsupported model type: %s' % args.model)

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    try:
        model.load_state_dict(saved_state_dict)
    except RuntimeError:
        # Key mismatch: the checkpoint was presumably saved from a
        # DataParallel wrapper ("module."-prefixed keys), so wrap first.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(saved_state_dict)
    # Wrap exactly once, whichever path loaded the weights.
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.eval()
    model.cuda(gpu0)

    testloader = data.DataLoader(robotDataSet(args.data_dir,
                                              args.data_list,
                                              crop_size=(960, 1280),
                                              resize_size=(1280, 960),
                                              mean=IMG_MEAN,
                                              scale=False,
                                              mirror=False,
                                              set=args.set),
                                 batch_size=batchsize,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=4)

    # Second, upscaled pass over the same list for the multi-scale ensemble.
    scale = 1.25
    testloader2 = data.DataLoader(robotDataSet(
        args.data_dir,
        args.data_list,
        crop_size=(round(960 * scale), round(1280 * scale)),
        resize_size=(round(1280 * scale), round(960 * scale)),
        mean=IMG_MEAN,
        scale=False,
        mirror=False,
        set=args.set),
                                  batch_size=batchsize,
                                  shuffle=False,
                                  pin_memory=True,
                                  num_workers=4)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(960, 1280),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(960, 1280), mode='bilinear')

    sm = torch.nn.Softmax(dim=1)
    for index, (batch, batch2) in enumerate(zip(testloader, testloader2)):
        image, _, _, name = batch
        image2, _, _, _ = batch2
        print(image.shape)

        inputs = image.cuda()
        inputs2 = image2.cuda()
        print('\r>>>>Extracting feature...%04d/%04d' %
              (index * batchsize, NUM_STEPS),
              end='')
        # Inference only: no autograd graph is needed for either model family.
        with torch.no_grad():
            if args.model == 'DeepLab':
                output1, output2 = model(inputs)
                output_batch = interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs

                output1, output2 = model(inputs2)
                output_batch += interp(sm(0.5 * output1 + output2))
                output1, output2 = model(fliplr(inputs2))
                output1, output2 = fliplr(output1), fliplr(output2)
                output_batch += interp(sm(0.5 * output1 + output2))
                del output1, output2, inputs2
                output_batch = output_batch.cpu().data.numpy()
            else:  # 'DeeplabVGG' or 'Oracle'
                output_batch = model(inputs)
                output_batch = interp(output_batch).cpu().data.numpy()

        # (N, C, H, W) -> (N, H, W, C) -> per-pixel arg-max label ids.
        output_batch = output_batch.transpose(0, 2, 3, 1)
        output_batch = np.asarray(np.argmax(output_batch, axis=3),
                                  dtype=np.uint8)

        for i in range(output_batch.shape[0]):
            output = output_batch[i, :, :]
            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            path_parts = name[i].split('/')
            name_tmp = path_parts[-1]
            dir_name = path_parts[-2]
            save_path = args.save + '/' + dir_name
            os.makedirs(save_path, exist_ok=True)
            output.save('%s/%s' % (save_path, name_tmp))
            print('%s/%s' % (save_path, name_tmp))
            output_col.save('%s/%s_color.png' %
                            (save_path, name_tmp.split('.')[0]))

    return args.save
Exemple #17
0
class AD_Trainer(nn.Module):
    """Segmentation network plus two output-space discriminators and optimizers.

    ``G`` is a ``DeeplabMulti`` with two prediction heads.  ``D1``/``D2`` are
    ``MsImageDis`` discriminators fed with softmaxed predictions.
    ``gen_update`` performs one SGD step on ``G``, and ``dis_update`` one Adam
    step on each discriminator.  All modules and weight tensors are moved to
    the default CUDA device.
    """

    def __init__(self, args):
        """Build G, D1, D2 and their optimizers from the options in ``args``.

        Raises:
            ValueError: if ``args.model`` is not ``'DeepLab'``.
        """
        super(AD_Trainer, self).__init__()
        self.fp16 = args.fp16
        self.class_balance = args.class_balance
        self.often_balance = args.often_balance
        self.num_classes = args.num_classes
        self.class_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.often_weight = torch.FloatTensor(
            self.num_classes).zero_().cuda() + 1
        self.multi_gpu = args.multi_gpu
        self.only_hard_label = args.only_hard_label

        # Only the DeepLab generator is wired up below; any other value would
        # leave self.G / new_params undefined.
        if args.model != 'DeepLab':
            raise ValueError('Unsupported model type: %s' % args.model)

        self.G = DeeplabMulti(num_classes=args.num_classes,
                              use_se=args.use_se,
                              train_bn=args.train_bn,
                              norm_style=args.norm_style,
                              droprate=args.droprate)
        from_url = args.restore_from[:4] == 'http'
        if from_url:
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = self.G.state_dict().copy()
        for i in saved_state_dict:
            # e.g. Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if from_url:
                # URL checkpoints: drop the leading scope and skip the
                # 'fc' and 'layer5' parameters.
                if i_parts[1] != 'fc' and i_parts[1] != 'layer5':
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                    print('%s is loaded from pre-trained weight.\n' %
                          i_parts[1:])
            elif i_parts[0] == 'module':
                # Keys saved from a DataParallel wrapper: strip "module.".
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                print('%s is loaded from pre-trained weight.\n' %
                      i_parts[1:])
            else:
                new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
                print('%s is loaded from pre-trained weight.\n' %
                      i_parts[0:])
        self.G.load_state_dict(new_params)

        self.D1 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D2 = MsImageDis(input_dim=args.num_classes).cuda()
        self.D1.apply(weights_init('gaussian'))
        self.D2.apply(weights_init('gaussian'))

        if self.multi_gpu and args.sync_bn:
            print("using apex synced BN")
            self.G = apex.parallel.convert_syncbn_model(self.G)

        self.gen_opt = optim.SGD(self.G.optim_parameters(args),
                                 lr=args.learning_rate,
                                 momentum=args.momentum,
                                 nesterov=True,
                                 weight_decay=args.weight_decay)

        self.dis1_opt = optim.Adam(self.D1.parameters(),
                                   lr=args.learning_rate_D,
                                   betas=(0.9, 0.99))

        self.dis2_opt = optim.Adam(self.D2.parameters(),
                                   lr=args.learning_rate_D,
                                   betas=(0.9, 0.99))

        self.seg_loss = nn.CrossEntropyLoss(ignore_index=255)
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        self.kl_loss = nn.KLDivLoss(reduction='sum')
        self.sm = torch.nn.Softmax(dim=1)
        self.log_sm = torch.nn.LogSoftmax(dim=1)
        # D1/D2 were already moved to the GPU when constructed above.
        self.G = self.G.cuda()
        self.interp = nn.Upsample(size=args.crop_size,
                                  mode='bilinear',
                                  align_corners=True)
        self.interp_target = nn.Upsample(size=args.crop_size,
                                         mode='bilinear',
                                         align_corners=True)
        self.lambda_seg = args.lambda_seg
        self.max_value = args.max_value
        self.lambda_me_target = args.lambda_me_target
        self.lambda_kl_target = args.lambda_kl_target
        self.lambda_adv_target1 = args.lambda_adv_target1
        self.lambda_adv_target2 = args.lambda_adv_target2
        self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1
        if args.fp16:
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            # apex expects a single amp.initialize call that receives every
            # model and optimizer together.
            models, optimizers = amp.initialize(
                [self.G, self.D1, self.D2],
                [self.gen_opt, self.dis1_opt, self.dis2_opt],
                opt_level="O1")
            self.G, self.D1, self.D2 = models
            self.gen_opt, self.dis1_opt, self.dis2_opt = optimizers

    def update_class_criterion(self, labels):
        """Rebuild the weighted segmentation loss from the label statistics.

        Classes with fewer than ``64 * 64 * n`` pixels in ``labels`` get weight
        ``max_value``.  With ``often_balance``, classes absent from the batch
        also raise a running (0.9/0.1 EMA) ``often_weight``.

        Returns:
            A ``CrossEntropyLoss`` weighted by the updated ``class_weight``.
        """
        weight = torch.FloatTensor(self.num_classes).zero_().cuda()
        weight += 1
        count = torch.FloatTensor(self.num_classes).zero_().cuda()
        often = torch.FloatTensor(self.num_classes).zero_().cuda()
        often += 1
        print(labels.shape)
        n, _, _ = labels.shape
        for i in range(self.num_classes):
            count[i] = torch.sum(labels == i)
            if count[i] < 64 * 64 * n:  # small objective
                weight[i] = self.max_value
        if self.often_balance:
            often[count == 0] = self.max_value

        self.often_weight = 0.9 * self.often_weight + 0.1 * often
        self.class_weight = weight * self.often_weight
        print(self.class_weight)
        return nn.CrossEntropyLoss(weight=self.class_weight, ignore_index=255)

    def update_label(self, labels, prediction):
        """Ignore (set to 255) pixels whose loss is below a percentile.

        ``self.only_hard_label`` is the percentile of the per-pixel weighted
        cross-entropy below which labels are dropped, so only the harder pixels
        keep supervising.  ``labels`` is modified in place and returned.
        """
        criterion = nn.CrossEntropyLoss(weight=self.class_weight,
                                        ignore_index=255,
                                        reduction='none')
        loss = criterion(prediction, labels)
        print('original loss: %f' % self.seg_loss(prediction, labels))
        loss_data = loss.data.cpu().numpy()
        mm = np.percentile(loss_data[:], self.only_hard_label)
        labels[loss < mm] = 255
        return labels

    def gen_update(self, images, images_t, labels, labels_t, i_iter):
        """Run one optimisation step of the segmentation network ``G``.

        Args:
            images, labels: source batch and its label map (255 = ignore).
            images_t, labels_t: target batch and labels.  ``labels_t`` only
                feeds the returned monitoring loss ``val_loss``.
            i_iter: global iteration.  The confidence-weighted entropy and the
                KL consistency terms are disabled while ``i_iter < 15000``.

        Returns:
            (loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2,
            loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2,
            val_loss).  ``loss_me``/``loss_kl`` are ``0.0`` when disabled.
        """
        self.gen_opt.zero_grad()

        pred1, pred2 = self.G(images)
        pred1 = self.interp(pred1)
        pred2 = self.interp(pred2)

        if self.class_balance:
            self.seg_loss = self.update_class_criterion(labels)

        if self.only_hard_label > 0:
            labels1 = self.update_label(labels.clone(), pred1)
            labels2 = self.update_label(labels.clone(), pred2)
            loss_seg1 = self.seg_loss(pred1, labels1)
            loss_seg2 = self.seg_loss(pred2, labels2)
        else:
            loss_seg1 = self.seg_loss(pred1, labels)
            loss_seg2 = self.seg_loss(pred2, labels)

        loss = loss_seg2 + self.lambda_seg * loss_seg1

        # target
        pred_target1, pred_target2 = self.G(images_t)
        pred_target1 = self.interp_target(pred_target1)
        pred_target2 = self.interp_target(pred_target2)

        # calc_gen_loss lives on the wrapped module when running multi-GPU;
        # the (possibly wrapped) discriminator itself is always passed in.
        d1_impl = self.D1.module if self.multi_gpu else self.D1
        d2_impl = self.D2.module if self.multi_gpu else self.D2
        loss_adv_target1 = d1_impl.calc_gen_loss(
            self.D1, input_fake=F.softmax(pred_target1, dim=1))
        loss_adv_target2 = d2_impl.calc_gen_loss(
            self.D2, input_fake=F.softmax(pred_target2, dim=1))

        loss += self.lambda_adv_target1 * loss_adv_target1 + self.lambda_adv_target2 * loss_adv_target2

        # Warm-up: the unsupervised target terms start after 15000 iterations.
        if i_iter < 15000:
            self.lambda_kl_target_copy = 0
            self.lambda_me_target_copy = 0
        else:
            self.lambda_kl_target_copy = self.lambda_kl_target
            self.lambda_me_target_copy = self.lambda_me_target

        loss_me = 0.0
        if self.lambda_me_target_copy > 0:
            fused_target = 0.5 * pred_target1 + pred_target2
            confidence_map = torch.sum(self.sm(fused_target)**2, 1).detach()
            loss_me = -torch.mean(confidence_map * torch.sum(
                self.sm(fused_target) * self.log_sm(fused_target), 1))
            loss += self.lambda_me_target_copy * loss_me

        loss_kl = 0.0
        if self.lambda_kl_target_copy > 0:
            n, c, h, w = pred_target1.shape
            with torch.no_grad():
                # Fused prediction acts as a fixed target for both heads.
                mean_pred = self.sm(0.5 * pred_target1 + pred_target2)
            loss_kl = (self.kl_loss(self.log_sm(pred_target2), mean_pred) +
                       self.kl_loss(self.log_sm(pred_target1), mean_pred)) / (
                           n * h * w)
            print(loss_kl)
            loss += self.lambda_kl_target_copy * loss_kl

        if self.fp16:
            with amp.scale_loss(loss, self.gen_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.gen_opt.step()

        # Monitoring only: no autograd graph is needed for this loss.
        with torch.no_grad():
            val_loss = self.seg_loss(pred_target2, labels_t)

        return loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2, val_loss

    def dis_update(self, pred1, pred2, pred_target1, pred_target2):
        """Run one optimisation step of both discriminators.

        The fused source prediction ``0.5 * pred1 + pred2`` is the "real" input
        for both D1 and D2.  The corresponding target heads are the "fake"
        inputs.  All predictions are detached so ``G`` receives no gradient.

        Returns:
            (loss_D1, loss_D2).
        """
        self.dis1_opt.zero_grad()
        self.dis2_opt.zero_grad()
        pred1 = pred1.detach()
        pred2 = pred2.detach()
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        d1_impl = self.D1.module if self.multi_gpu else self.D1
        d2_impl = self.D2.module if self.multi_gpu else self.D2
        source_prob = F.softmax(0.5 * pred1 + pred2, dim=1)
        loss_D1, _ = d1_impl.calc_dis_loss(
            self.D1,
            input_fake=F.softmax(pred_target1, dim=1),
            input_real=source_prob)
        loss_D2, _ = d2_impl.calc_dis_loss(
            self.D2,
            input_fake=F.softmax(pred_target2, dim=1),
            input_real=source_prob)

        loss = loss_D1 + loss_D2
        if self.fp16:
            with amp.scale_loss(loss,
                                [self.dis1_opt, self.dis2_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.dis1_opt.step()
        self.dis2_opt.step()
        return loss_D1, loss_D2
def main(args):
    """Create the model and start the training.

    Each iteration accumulates ``args.iter_size`` sub-steps of four parts:

    1. supervised segmentation loss on a labelled source batch,
    2. supervised segmentation loss on a labelled target batch,
    3. adversarial loss pushing predictions on an unlabelled target batch
       towards ``model_D2``'s source label,
    4. discriminator loss on the detached source and unlabelled-target
       predictions.

    After the sub-steps, both optimizers are stepped.  Only the second DeepLab
    head (``pred2``) contributes to the losses.  Losses are logged to
    TensorBoard under ``<snapshot_dir>/tb-logs``.  Checkpoints are written
    every ``args.save_pred_every`` iterations and when ``args.num_steps_stop``
    is reached.

    Args:
        args: parsed command-line options.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'``.
    """

    mkdir_check(args.snapshot_dir)

    h, w = map(int, args.src_input_size.split(','))
    src_input_size = (h, w)

    h, w = map(int, args.tgt_input_size.split(','))
    tgt_input_size = (h, w)

    cudnn.enabled = True

    # Create network
    if args.model != 'DeepLab':
        raise ValueError('Unsupported model type: {}'.format(args.model))
    model = DeeplabMulti(num_classes=args.num_classes)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # Strip the leading scope of each key (e.g. "Scale." in
    # "Scale.layer5.conv2d_list.3.weight"); for 19 classes the "layer5"
    # parameters are not copied and keep the fresh model's values.
    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        key_parts = key.split('.')
        if not args.num_classes == 19 or not key_parts[1] == 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D2 = FCDiscriminator(num_classes=19)

    model_D2.train()
    model_D2.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    max_iters = args.num_steps * args.iter_size * args.batch_size

    trainloader = data.DataLoader(
            ListDataSet(args.src_data_dir, args.src_img_list, args.src_lbl_list,
                max_iters=max_iters, crop_size=src_input_size, mean=IMG_MEAN),
            batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(
            ListDataSet(args.tgt_data_dir, args.tgt_img_list, args.tgt_lbl_list,
                max_iters=max_iters, crop_size=tgt_input_size, mean=IMG_MEAN),
            batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
            pin_memory=True)

    targetloader_nolabel = data.DataLoader(
            ListDataSet(args.tgt_data_dir, args.tgt_img_nolabel_list, None,
                max_iters=max_iters, crop_size=tgt_input_size, mean=IMG_MEAN),
            batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
            pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    targetloader_nolabel_iter = enumerate(targetloader_nolabel)

    # model.optim_parameters(args) handles the per-layer lr settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # NOTE(review): the sizes above are parsed as (h, w) but Upsample gets
    # (size[1], size[0]); confirm which order ListDataSet's crop_size uses
    # before changing either side.
    interp = nn.Upsample(size=(src_input_size[1], src_input_size[0]), mode='bilinear', align_corners=True)
    interp_target = nn.Upsample(size=(tgt_input_size[1], tgt_input_size[0]), mode='bilinear', align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1

    writer = SummaryWriter(log_dir=os.path.join(args.snapshot_dir, 'tb-logs'))
    try:
        for i_iter in range(args.num_steps):

            loss_seg_value2 = 0
            loss_tgt_seg_value2 = 0
            loss_adv_target_value2 = 0
            loss_D_value2 = 0

            optimizer.zero_grad()
            adjust_learning_rate(args, optimizer, i_iter)

            optimizer_D2.zero_grad()
            adjust_learning_rate_D(args, optimizer_D2, i_iter)

            for sub_i in range(args.iter_size):

                # ---- train G ----
                # Freeze D so the adversarial loss only updates the segmenter.
                for param in model_D2.parameters():
                    param.requires_grad = False

                # Supervised loss on labelled source images.
                _, batch = next(trainloader_iter)
                images, labels, _, _ = batch
                images = images.cuda(args.gpu)

                _, pred2 = model(images)
                pred2 = interp(pred2)

                loss_seg2 = loss_calc(pred2, labels, args.gpu)
                # Normalise so accumulated sub-iterations average out.
                loss = loss_seg2 / args.iter_size
                loss.backward()
                loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

                # Supervised loss on labelled target images.  This graph is
                # not reused, so it does not need to be retained.
                _, batch = next(targetloader_iter)
                images, labels, _, _ = batch
                images = images.cuda(args.gpu)

                _, pred_target2 = model(images)
                pred_target2 = interp_target(pred_target2)

                loss_tgt_seg2 = loss_calc(pred_target2, labels, args.gpu)
                loss = loss_tgt_seg2 / args.iter_size
                loss.backward()
                loss_tgt_seg_value2 += loss_tgt_seg2.data.cpu().numpy() / args.iter_size

                # Adversarial loss on unlabelled target images: make D
                # classify them as source.  Softmax runs over the class
                # channel (dim=1) of the (N, C, H, W) prediction.
                _, batch = next(targetloader_nolabel_iter)
                images, _, _, _ = batch
                images = images.cuda(args.gpu)

                _, pred_target2 = model(images)
                pred_target2 = interp_target(pred_target2)

                D_out2 = model_D2(F.softmax(pred_target2, dim=1))

                loss_adv_target2 = bce_loss(D_out2, torch.full_like(D_out2, source_label))

                loss = args.lambda_adv_target2 * loss_adv_target2
                loss = loss / args.iter_size
                loss.backward()
                loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy() / args.iter_size

                # ---- train D ----
                # bring back requires_grad
                for param in model_D2.parameters():
                    param.requires_grad = True

                # Source predictions -> source label.
                D_out2 = model_D2(F.softmax(pred2.detach(), dim=1))

                loss_D2 = bce_loss(D_out2, torch.full_like(D_out2, source_label))
                loss_D2 = loss_D2 / args.iter_size / 2
                loss_D2.backward()
                loss_D_value2 += loss_D2.data.cpu().numpy()

                # Unlabelled-target predictions -> target label.
                D_out2 = model_D2(F.softmax(pred_target2.detach(), dim=1))

                loss_D2 = bce_loss(D_out2, torch.full_like(D_out2, target_label))
                loss_D2 = loss_D2 / args.iter_size / 2
                loss_D2.backward()
                loss_D_value2 += loss_D2.data.cpu().numpy()

            optimizer.step()
            optimizer_D2.step()

            print(
            'iter = {:5d}/{:8d}, loss_seg2 = {:.3f} loss_tgt_seg2 = {:.3f} loss_adv2 = {:.3f} loss_D2 = {:.3f}'.format(
                i_iter, args.num_steps_stop, loss_seg_value2, loss_tgt_seg_value2,
                loss_adv_target_value2, loss_D_value2))

            writer.add_scalars('loss/seg', {
                'src2': loss_seg_value2,
                'tgt2': loss_tgt_seg_value2,
                }, i_iter)

            writer.add_scalar('loss/adv', loss_adv_target_value2, i_iter)
            writer.add_scalar('loss/d', loss_D_value2, i_iter)

            if i_iter >= args.num_steps_stop - 1:
                print('save model ...')
                torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'model_{}.pth'.format(i_iter)))
                torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'model_d_{}.pth'.format(i_iter)))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'model_{}.pth'.format(i_iter)))
                torch.save(model_D2.state_dict(), osp.join(args.snapshot_dir, 'model_d_{}.pth'.format(i_iter)))
    finally:
        # Flush pending TensorBoard events even on early exit or error.
        writer.close()
def main():
    """Create the model and start the training.

    Adversarially adapts a two-head ``DeeplabMulti`` segmentation network
    from GTA5 (source, ``GTA5DataSet``) to Cityscapes (target,
    ``cityscapesDataSet``) with two ``FCDiscriminator`` models, one per
    prediction head. Each discriminator input is a *pair*: the class-softmax
    maps of two predictions concatenated along the channel axis. Pairs are
    regressed with ``MSELoss`` towards ``source_label`` (1), ``target_label``
    (-1) or ``mix_label`` (0). The previous step's batches (``batch_last`` /
    ``batch_last_target``) supply the second element of each pair.

    Configuration is read from the module-level ``args``. The segmentation
    network runs on ``gpu_id_2``; discriminators, labels and losses live on
    ``gpu_id_1``. Snapshots are written to ``args.snapshot_dir``.
    """

    gpu_id_2 = 3  # segmentation network
    gpu_id_1 = 2  # discriminators, labels and losses

    # NOTE(review): when not restoring from a URL, training resumes from this
    # fixed snapshot (not from ``args.restore_from``) at ``start_iter``; the
    # two values belong together and must be kept in sync.
    resume_checkpoint = 'snapshots/GTA2Cityscapes_multi_54/GTA5_10000.pth'
    start_iter = 10000

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            print("from url")
            saved_state_dict = model_zoo.load_url(args.restore_from)
            # Downloaded keys carry an extra leading component
            # (e.g. ``Scale.layer5.conv2d_list.3.weight``) that is stripped;
            # ``layer5`` entries are skipped when num_classes == 19.
            # Previously these converted weights were built but never loaded.
            new_params = model.state_dict().copy()
            for key in saved_state_dict:
                key_parts = key.split('.')
                if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                    new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
            model.load_state_dict(new_params)
        else:
            print("from restore")
            # The snapshot already uses the model's key names, so it is loaded
            # directly (the old code also loaded args.restore_from and threw
            # it away).
            model.load_state_dict(torch.load(resume_checkpoint))

    model.train()
    model.cuda(gpu_id_2)

    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(gpu_id_1)

    model_D2.train()
    model_D2.cuda(gpu_id_1)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    # ``enumerate`` objects have no ``.next()`` method on Python 3.
    _, batch_last = next(trainloader_iter)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    _, batch_last_target = next(targetloader_iter)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    # Least-squares adversarial objective (name kept from the BCE original).
    bce_loss = torch.nn.MSELoss()

    def upsample_(input_):
        """Resize logits to the source crop size (``input_size`` is (w, h))."""
        return nn.functional.interpolate(input_,
                                         size=(input_size[1], input_size[0]),
                                         mode='bilinear',
                                         align_corners=False)

    def upsample_target(input_):
        """Resize logits to the target crop size."""
        return nn.functional.interpolate(input_,
                                         size=(input_size_target[1],
                                               input_size_target[0]),
                                         mode='bilinear',
                                         align_corners=False)

    interp = upsample_
    interp_target = upsample_target

    # labels for adversarial training
    source_label = 1
    target_label = -1
    mix_label = 0

    def set_discriminators_trainable(flag):
        """Toggle ``requires_grad`` on every parameter of both discriminators."""
        for param in model_D1.parameters():
            param.requires_grad = flag
        for param in model_D2.parameters():
            param.requires_grad = flag

    def result_model(batch, interp_):
        """Run the segmentation net on ``batch``.

        Returns the two upsampled prediction heads and the long labels, all
        moved to ``gpu_id_1``.
        """
        images, labels, _, _ = batch
        images = images.cuda(gpu_id_2)
        labels = labels.long().cuda(gpu_id_1)
        pred1, pred2 = model(images)
        pred1 = interp_(pred1).cuda(gpu_id_1)
        pred2 = interp_(pred2).cuda(gpu_id_1)
        return pred1, pred2, labels

    def pair(first, second):
        """Concatenate the class-softmax of two predictions along channels."""
        return torch.cat((F.softmax(first, dim=1), F.softmax(second, dim=1)),
                         dim=1)

    def adv_loss(d_out, label):
        """MSE between a discriminator output and a constant ``label`` map."""
        return bce_loss(d_out, torch.full_like(d_out, label))

    for i_iter in range(start_iter, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for _ in range(args.iter_size):

            # ---------------- train G ----------------
            # don't accumulate grads in D
            set_discriminators_trainable(False)

            _, batch = next(trainloader_iter)
            _, batch_target = next(targetloader_iter)

            # supervised segmentation loss on the source batch
            pred1, pred2, labels = result_model(batch, interp)
            loss_seg1 = loss_calc(pred1, labels, gpu_id_1)
            loss_seg2 = loss_calc(pred2, labels, gpu_id_1)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()

            # segmentation loss on the target batch (its labels are used too)
            pred1, pred2, labels = result_model(batch_target, interp_target)
            loss_seg1 = loss_calc(pred1, labels, gpu_id_1)
            loss_seg2 = loss_calc(pred2, labels, gpu_id_1)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            # fool D on (target, previous target) pairs
            pred1_last_target, pred2_last_target, _ = result_model(
                batch_last_target, interp_target)
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)

            D_out_fake_1 = model_D1(pair(pred1_target, pred1_last_target))
            # head-2 discriminator now receives head-2 pairs (was fake1_D)
            D_out_fake_2 = model_D2(pair(pred2_target, pred2_last_target))

            loss_adv_target1 = adv_loss(D_out_fake_1, source_label)
            loss_adv_target2 = adv_loss(D_out_fake_2, source_label)
            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()

            # fool D on (target, source) pairs
            pred1, pred2, _ = result_model(batch, interp)
            pred1_target, pred2_target, _ = result_model(
                batch_target, interp_target)

            D_out_mix_1 = model_D1(pair(pred1_target, pred1))
            D_out_mix_2 = model_D2(pair(pred2_target, pred2))

            loss_adv_target1 = adv_loss(D_out_mix_1, source_label) * 2
            loss_adv_target2 = adv_loss(D_out_mix_2, source_label) * 2
            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            # ---------------- train D ----------------
            # bring back requires_grad
            set_discriminators_trainable(True)

            pred1_last, pred2_last, _ = result_model(batch_last, interp)

            # D must not backpropagate into G
            pred1 = pred1.detach()
            pred2 = pred2.detach()
            pred1_target = pred1_target.detach()
            pred2_target = pred2_target.detach()
            pred1_last = pred1_last.detach()
            pred2_last = pred2_last.detach()
            pred1_last_target = pred1_last_target.detach()
            pred2_last_target = pred2_last_target.detach()

            # (source, previous source) -> source; (previous source, target) -> mix
            D_out1_real = model_D1(pair(pred1, pred1_last))
            D_out2_real = model_D2(pair(pred2, pred2_last))
            D_out1_mix = model_D1(pair(pred1_last, pred1_target))
            D_out2_mix = model_D2(pair(pred2_last, pred2_target))

            loss_D1 = (adv_loss(D_out1_real, source_label) +
                       adv_loss(D_out1_mix, mix_label)) / args.iter_size / 2
            loss_D2 = (adv_loss(D_out2_real, source_label) +
                       adv_loss(D_out2_mix, mix_label)) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            # (target, previous target) -> target; (source, previous target) -> mix.
            # The fake pair previously paired the target map with itself; it
            # now matches the (target, previous target) pair used for G above.
            D_out1_fake = model_D1(pair(pred1_target, pred1_last_target))
            D_out2_fake = model_D2(pair(pred2_target, pred2_last_target))
            D_out1_mix = model_D1(pair(pred1, pred1_last_target))
            D_out2_mix = model_D2(pair(pred2, pred2_last_target))

            loss_D1 = (adv_loss(D_out1_fake, target_label) +
                       adv_loss(D_out1_mix, mix_label)) / args.iter_size / 2
            loss_D2 = (adv_loss(D_out2_fake, target_label) +
                       adv_loss(D_out2_mix, mix_label)) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            # this step's batches become the "previous" ones for the next step
            batch_last, batch_last_target = batch, batch_target
            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
Exemple #20
0
def main():
    """Create the model and start the evaluation process.

    Builds the model selected by ``args.model`` and restores its weights
    from ``args.restore_from`` (a URL or a local checkpoint). It then runs the
    model over the Cityscapes split ``args.set``. For each image it writes a
    label-id PNG and a colorized PNG to ``<args.save>/0/``.
    """

    args = get_arguments()

    gpu0 = 2
    torch.cuda.manual_seed(1337)
    torch.cuda.set_device(gpu0)

    if not os.path.exists(args.save):
        os.makedirs(args.save)
    for i in range(5):
        if not os.path.exists(args.save + '/' + str(i)):
            os.makedirs(args.save + '/' + str(i))

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG
    print("begin")

    if args.restore_from[:4] == 'http':
        # Previously this branch only printed a marker and evaluated an
        # untrained network.
        from torch.utils import model_zoo
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        print(args.restore_from)
        saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)
    model.cuda(gpu0)
    # NOTE(review): model.eval() is not called, so normalization layers (if
    # any) run in training mode during evaluation; confirm this is intended.

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = Upsample_function
    else:
        interp = nn.Upsample(size=(1024, 2048), mode='bilinear')

    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processd' % index)
            image, _, _, name = batch
            image = image.cuda(gpu0)
            final = []  # per-image score maps, one per fusion weight
            if args.model == 'DeeplabMulti':
                output1, output2 = model(image)
                output1 = F.softmax(output1, 1)
                output2 = F.softmax(output2, 1)
                # Fuse as w/10 * output1 + (10 - w)/10 * output2. The break
                # keeps only the first weight (w = 0, i.e. output2 alone).
                for i in [0, 3, 7, 10]:
                    final_output = i / 10.0 * output1 + (10.0 -
                                                         i) / 10.0 * output2
                    final.append(interp(final_output).cpu().data[0].numpy())
                    break
            elif args.model == 'DeeplabVGG':
                output = model(image)
                # Previously never added to ``final``, so saving raised IndexError.
                final.append(interp(output).cpu().data[0].numpy())

            name = name[0].split('/')[-1]
            # Only the first fused map is saved (into sub-directory 0).
            for i, output in enumerate(final):
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                output.save('%s/%s/%s' % (args.save, str(i), name))
                output_col.save('%s/%s/%s_color.png' %
                                (args.save, str(i), name.split('.')[0]))
                break
Exemple #21
0
def main():
    """Create the model and start the evaluation process.

    Restores the model selected by ``args.model`` and evaluates it on
    ``SmokeDataset``. Images whose prediction contains more than one class
    get four things:

    * label-id and colorized PNGs saved to ``args.save``;
    * foreground and background IoU computed with ``iou_pytorch``;
    * inputs, labels, predictions and IoUs logged to TensorBoard;
    * a contribution to the mean IoUs printed at the end.
    """

    args = get_arguments()

    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        model = Res_Deeplab(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    # For running different versions of pytorch: overlay the checkpoint
    # entries the model knows about onto the model's own state, then load
    # the merged dict. Loading the filtered checkpoint alone (as before)
    # fails in strict mode whenever the model has keys the checkpoint lacks.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in saved_state_dict.items() if k in model_dict})
    model.load_state_dict(model_dict)

    model.eval()
    model.cuda(gpu0)

    log_dir = args.save
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)
    exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
    log_dir = os.path.join(log_dir, exp_name)
    writer = SummaryWriter(log_dir)

    testloader = data.DataLoader(SmokeDataset(image_size=(640, 360),
                                              dataset_mean=IMG_MEAN),
                                 batch_size=1,
                                 shuffle=True,
                                 pin_memory=True)

    # The former torch-version branches built this exact same module.
    interp = nn.Upsample(size=(640, 360), mode='bilinear', align_corners=True)

    iou_sum_fg = 0
    iou_count_fg = 0

    iou_sum_bg = 0
    iou_count_bg = 0

    for index, batch in enumerate(testloader):
        if (index + 1) % 100 == 0:
            print('%d processd' % index)

        image, label, name = batch
        with torch.no_grad():
            if args.model == 'DeeplabMulti':
                _, logits = model(image.cuda(gpu0))  # second head only
            elif args.model == 'DeeplabVGG' or args.model == 'Oracle':
                logits = model(image.cuda(gpu0))

        output = interp(logits).cpu()
        # Kept as a tensor for the IoU masks below. It used to be defined
        # only for DeeplabMulti, so other models hit a NameError.
        orig_output = output.detach().clone()
        output = output.data[0].numpy()

        output = output.transpose(1, 2, 0)
        output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)
        classes_seen = set(output.ravel().tolist())
        output_col = colorize_mask(output)
        output = Image.fromarray(output)
        name = name[0]

        if len(classes_seen) > 1:
            print(classes_seen)
            print(Counter(np.asarray(output).ravel()))
            # undo mean subtraction and rescale to [0, 1] for display
            image = image.squeeze()
            for c in range(3):
                image[c, :, :] += IMG_MEAN[c]
            image = (image - image.min()) / (image.max() - image.min())
            image = image[[2, 1, 0], :, :]
            print(image.shape, image.min(), image.max())
            output.save(os.path.join(args.save, name + '.png'))
            output_col.save(os.path.join(args.save, name + '_color.png'))

            output_argmaxs = torch.argmax(orig_output.squeeze(), dim=0)
            mask1 = (output_argmaxs == 0).float() * 255
            label = label.squeeze()

            iou_fg = iou_pytorch(mask1, label)
            print("foreground IoU", iou_fg)
            iou_sum_fg += iou_fg
            iou_count_fg += 1

            mask2 = (output_argmaxs > 0).float() * 255
            label2 = label.max() - label  # inverted label for the other class

            iou_bg = iou_pytorch(mask2, label2)
            print("IoU for background: ", iou_bg)
            iou_sum_bg += iou_bg
            iou_count_bg += 1

            writer.add_images('input_images',
                              tf.resize(image[[2, 1, 0]], [1080, 1920]),
                              index,
                              dataformats='CHW')

            print("shape of label", label.shape)
            label_reshaped = tf.resize(label.unsqueeze(0),
                                       [1080, 1920]).squeeze()
            print("label reshaped: ", label_reshaped.shape)
            writer.add_images('labels',
                              label_reshaped,
                              index,
                              dataformats='HW')
            writer.add_images(
                'output/1',
                255 - np.asarray(tf.resize(output, [1080, 1920])) * 255,
                index,
                dataformats='HW')
            writer.add_scalar('iou/smoke', iou_fg, index)
            writer.add_scalar('iou/background', iou_bg, index)
            writer.add_scalar('iou/mean', (iou_bg + iou_fg) / 2, index)
            writer.flush()

    if iou_count_fg > 0:
        print("Mean IoU, foreground: {}".format(iou_sum_fg / iou_count_fg))
        print("Mean IoU, background: {}".format(iou_sum_bg / iou_count_bg))
        print("Mean IoU, averaged over classes: {}".format(
            (iou_sum_fg + iou_sum_bg) / (iou_count_fg + iou_count_bg)))
def main():
    """Create the model and start the training.

    Trains a two-head segmentation network (the model returns ``pred1`` and
    ``pred2``) on labelled GTA5 batches while adapting to unlabelled
    Cityscapes batches with:

    * output-space adversarial losses from two discriminators
      (``model_D1`` sees the first head, ``model_D2`` the second), and
    * a consistency loss between predictions on RandAugment-ed target images
      and confidence-thresholded pseudo-labels taken from the un-augmented
      target predictions.

    Mixed precision is handled by apex ``amp`` (opt level O1). All
    configuration is read from the module-level ``args``.

    Every 10 iterations training statistics and a figure are logged to
    ``writer['train']``. Every 500 iterations the model is evaluated on the
    Cityscapes ``val`` split and the results are logged to ``writer['test']``.
    The weights are saved as ``GTA5_best_iter*.pth`` whenever the validation
    mIoU improves. Training stops after ``args.num_steps_stop`` iterations,
    and the final weights are saved as ``GTA5_<num_steps_stop>*.pth``.

    Raises:
        ValueError: if ``args.model`` or ``args.gan`` is not supported.
    """
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Pseudo-label confidence threshold. It lives on the GPU so it can be
    # compared directly with the per-pixel softmax maxima.
    tau = (torch.ones(1) * args.tau).cuda(args.gpu)

    # ---- segmentation network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # Checkpoint keys carry one extra leading component
        # (e.g. "Scale.layer5.conv2d_list.3.weight"), which is dropped here.
        # With 19 classes the 'layer5' entries are not copied.
        new_params = model.state_dict().copy()
        for key in saved_state_dict:
            key_parts = key.split('.')
            if not args.num_classes == 19 or not key_parts[1] == 'layer5':
                new_params['.'.join(key_parts[1:])] = saved_state_dict[key]
        model.load_state_dict(new_params, False)  # strict=False
    elif args.model == 'DeepLabVGG':
        model = DeeplabVGG(pretrained=True, num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # ---- output-space discriminators, one per prediction head
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    # ---- data
    # Per-channel statistics used by normalize_transform. Normalisation is
    # applied per batch, and the un-normalised batches are kept for
    # RandAugment and for plotting.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    weak_transform = transforms.Compose([transforms.ToTensor()])
    target_transform = transforms.Compose([])

    label_set = GTA5(
        root=args.data_dir,
        num_cls=19,
        split='all',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size,
    )
    unlabel_set = Cityscapes(
        root=args.data_dir_target,
        split=args.set,
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
    )
    test_set = Cityscapes(
        root=args.data_dir_target,
        split='val',
        remap_labels=True,
        transform=weak_transform,
        target_transform=target_transform,
        scale=input_size_target,
    )

    label_loader = data.DataLoader(label_set,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=False)
    unlabel_loader = data.DataLoader(unlabel_set,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=False)
    test_loader = data.DataLoader(test_set,
                                  batch_size=2,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  pin_memory=False)

    # ---- optimizers (model.optim_parameters handles per-layer lr settings)
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))

    # Register every model and optimizer with amp. Each model must appear
    # exactly once. num_losses=7 matches the loss_id values 0-6 used with
    # amp.scale_loss below.
    [model, model_D1, model_D2], [optimizer, optimizer_D1, optimizer_D2] = \
        amp.initialize([model, model_D1, model_D2],
                       [optimizer, optimizer_D1, optimizer_D2],
                       opt_level="O1",
                       num_losses=7)
    discriminators = (model_D1, model_D2)

    optimizer.zero_grad()
    optimizer_D1.zero_grad()
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError('Unsupported gan type: {!r}'.format(args.gan))

    interp = Interpolate(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = Interpolate(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_test = Interpolate(size=(input_size_target[1],
                                    input_size_target[0]),
                              mode='bilinear',
                              align_corners=True)

    normalize_transform = transforms.Compose([
        torchvision.transforms.Normalize(mean=mean, std=std),
    ])

    # labels for adversarial training
    source_label = 0
    target_label = 1

    # ---- helpers
    def set_requires_grad(nets, flag):
        """Enable/disable gradient accumulation for all params of ``nets``."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad = flag

    def domain_target(d_out, domain_label):
        """Float32 tensor shaped like ``d_out`` filled with ``domain_label``."""
        return torch.full(d_out.size(), float(domain_label),
                          dtype=torch.float32, device=d_out.device)

    def make_pseudo_label(pred_target, aug_params):
        """Build a hard pseudo-label map from detached target logits.

        Pixels whose max softmax probability is below ``tau`` are set to 255
        (the "don't care" value checked by the consistency loss). The map is
        then transformed with the same augmentation parameters as the
        augmented image.
        """
        max_prob, pseudo_label = torch.max(F.softmax(pred_target, dim=1), 1)
        pseudo_label = pseudo_label.cpu().numpy().astype(np.float32)
        # np.bool was removed in NumPy 1.24; use the builtin bool.
        pseudo_label[(max_prob < tau).cpu().numpy().astype(bool)] = 255
        pseudo_label = aug_batch_numpy(pseudo_label, *aug_params)
        return Variable(pseudo_label).cuda(args.gpu)

    def consistency_loss(pred_aug, pseudo_label):
        """Return ``(loss, logged_value)`` for one head's consistency term."""
        if (pseudo_label != 255).sum().item() > 0:
            loss_con = loss_calc(pred_aug, pseudo_label, args.gpu)
            return loss_con, loss_con.item() / args.iter_size
        # nll_loss doesn't support an all-ignored target
        return torch.tensor(0.0, requires_grad=True).cuda(args.gpu), 0.0

    def save_checkpoints(tag):
        """Save G, D1 and D2 weights as GTA5_<tag>{,_D1,_D2}.pth."""
        for net, suffix in ((model, ''), (model_D1, '_D1'),
                            (model_D2, '_D2')):
            torch.save(net.state_dict(),
                       osp.join(args.snapshot_dir,
                                'GTA5_' + tag + suffix + '.pth'))

    def print_per_class(title, values):
        """Print one "<class name> : <value>" line per entry of ``values``."""
        rows = np.array([[classes[i], v]
                         for i, v in enumerate(values)]).flatten()
        print((title + '\n{:>14s} : {}' * len(values)).format(*rows))

    def report_stats(confusion):
        """Print confusion-matrix statistics; return (pixel acc, mIoU)."""
        acc_overall, acc_percls, iu, fwIU = result_stats(confusion)
        mIoU = np.mean(iu)
        print_per_class('per cls IoU :', iu)
        print('mIoU : {:0.2f}'.format(mIoU))
        print('fwIoU : {:0.2f}'.format(fwIU))
        print('pixel acc : {:0.2f}'.format(acc_overall))
        print_per_class('per cls acc :', acc_percls)
        return acc_overall, mIoU

    def to_palette(label_map):
        """Colourise one 2-D class-index tensor with ``print_palette``."""
        return print_palette(
            Image.fromarray(
                label_map.cpu().numpy().astype(np.float32)).convert('L'))

    def add_panel(fig, position, image, title):
        """Draw ``image`` in subplot ``position`` of ``fig`` without axes."""
        ax = fig.add_subplot(position)
        ax.imshow(image)
        ax.axis("off")
        ax.set_title(title)

    def chw_to_hwc(image):
        """CHW tensor -> HWC numpy array for ``imshow``."""
        return image.cpu().numpy().transpose((1, 2, 0))

    max_mIoU = 0

    # Suffix 1/2 = first/second prediction head (and D1/D2).
    loss_names = ('seg1', 'adv1', 'dis1', 'con1',
                  'seg2', 'adv2', 'dis2', 'con2')
    loss_history = {name: [] for name in loss_names}

    hist = np.zeros((num_cls, num_cls))

    for i_iter, (batch, batch_un) in enumerate(
            zip(roundrobin_infinite(label_loader),
                roundrobin_infinite(unlabel_loader))):
        step = dict.fromkeys(loss_names, 0.0)

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # ==== train G: don't accumulate grads in D
        set_requires_grad(discriminators, False)

        # NOTE(review): losses are divided by args.iter_size, but there is no
        # inner loop accumulating iter_size sub-batches, so iter_size > 1 only
        # scales the gradients down.

        # -- supervised segmentation on source
        images, labels = batch
        images_orig = images
        images = transform_batch(images, normalize_transform).cuda(args.gpu)

        pred1, pred2 = model(images)
        pred1 = interp(pred1)
        pred2 = interp(pred2)

        loss_seg1 = loss_calc(pred1, labels, args.gpu)
        loss_seg2 = loss_calc(pred2, labels, args.gpu)
        loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()
        step['seg1'] += loss_seg1.item() / args.iter_size
        step['seg2'] += loss_seg2.item() / args.iter_size

        # -- adversarial: make target predictions look like source to D
        images_tar, labels_tar = batch_un
        images_tar_orig = images_tar
        images_tar = transform_batch(images_tar,
                                     normalize_transform).cuda(args.gpu)

        pred_target1, pred_target2 = model(images_tar)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        d_out1 = model_D1(F.softmax(pred_target1, dim=1))
        d_out2 = model_D2(F.softmax(pred_target2, dim=1))
        loss_adv_target1 = bce_loss(d_out1, domain_target(d_out1,
                                                          source_label))
        loss_adv_target2 = bce_loss(d_out2, domain_target(d_out2,
                                                          source_label))
        loss = (args.lambda_adv_target1 * loss_adv_target1 +
                args.lambda_adv_target2 * loss_adv_target2) / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
            scaled_loss.backward()
        step['adv1'] += loss_adv_target1.item() / args.iter_size
        step['adv2'] += loss_adv_target2.item() / args.iter_size

        # -- consistency (unsupervised phase): augmented target vs
        #    thresholded pseudo-labels, both augmented with the same params
        policies = RandAugment().get_batch_policy(args.batch_size)
        rand_p1 = np.random.random(size=args.batch_size)
        rand_p2 = np.random.random(size=args.batch_size)
        random_dir = np.random.choice([-1, 1], size=[args.batch_size, 2])
        aug_params = (policies, rand_p1, rand_p2, random_dir)

        images_aug_orig = aug_batch_tensor(images_tar_orig, *aug_params)
        images_aug = transform_batch(images_aug_orig,
                                     normalize_transform).cuda(args.gpu)

        pred_target_aug1, pred_target_aug2 = model(images_aug)
        pred_target_aug1 = interp_target(pred_target_aug1)
        pred_target_aug2 = interp_target(pred_target_aug2)

        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()

        pseudo_label1 = make_pseudo_label(pred_target1, aug_params)
        pseudo_label2 = make_pseudo_label(pred_target2, aug_params)

        loss_con1, con_value1 = consistency_loss(pred_target_aug1,
                                                 pseudo_label1)
        loss_con2, con_value2 = consistency_loss(pred_target_aug2,
                                                 pseudo_label2)
        step['con1'] += con_value1
        step['con2'] += con_value2

        loss = args.lambda_con * loss_con1 + args.lambda_con * loss_con2
        loss = loss / args.iter_size
        with amp.scale_loss(loss, optimizer, loss_id=2) as scaled_loss:
            scaled_loss.backward()

        # ==== train D: bring back requires_grad
        set_requires_grad(discriminators, True)

        pred1 = pred1.detach()
        pred2 = pred2.detach()

        # Source predictions -> source_label (loss ids 3/4), target
        # predictions -> target_label (loss ids 5/6). Each half is weighted 1/2.
        for domain_label, (d_in1, d_in2), loss_id in (
                (source_label, (pred1, pred2), 3),
                (target_label, (pred_target1, pred_target2), 5)):
            d_out1 = model_D1(F.softmax(d_in1, dim=1))
            d_out2 = model_D2(F.softmax(d_in2, dim=1))

            loss_D1 = bce_loss(d_out1, domain_target(d_out1, domain_label))
            loss_D2 = bce_loss(d_out2, domain_target(d_out2, domain_label))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            with amp.scale_loss(loss_D1, optimizer_D1,
                                loss_id=loss_id) as scaled_loss:
                scaled_loss.backward()
            with amp.scale_loss(loss_D2, optimizer_D2,
                                loss_id=loss_id + 1) as scaled_loss:
                scaled_loss.backward()

            step['dis1'] += loss_D1.item()
            step['dis2'] += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}, loss_con1 = {8:.3f}, loss_con2 = {9:.3f}'
            .format(i_iter, args.num_steps, step['seg1'], step['seg2'],
                    step['adv1'], step['adv2'], step['dis1'], step['dis2'],
                    step['con1'], step['con2']))

        for name in loss_names:
            loss_history[name].append(step[name])

        # Source-domain confusion matrix of the second head.
        hist += fast_hist(
            labels.cpu().numpy().flatten().astype(int),
            torch.argmax(pred2, dim=1).cpu().numpy().flatten().astype(int),
            num_cls)

        # ==== periodic training report
        if i_iter % 10 == 0:
            print('({}/{})'.format(i_iter + 1, int(args.num_steps)))
            avg_train_acc, train_mIoU = report_stats(hist)

            avg_losses = {name: np.mean(values)
                          for name, values in loss_history.items()}
            print('avg_train_acc      :', avg_train_acc)
            for name in loss_names:
                print('avg_train_loss_{} :'.format(name), avg_losses[name])

            writer['train'].add_scalar('log/mIoU', train_mIoU, i_iter)
            writer['train'].add_scalar('log/acc', avg_train_acc, i_iter)
            for name in loss_names:
                # e.g. 'seg1' -> 'log1/loss_seg'
                writer['train'].add_scalar(
                    'log{}/loss_{}'.format(name[-1], name[:-1]),
                    avg_losses[name], i_iter)

            hist = np.zeros((num_cls, num_cls))
            loss_history = {name: [] for name in loss_names}

            fig = plt.figure(figsize=(15, 15))
            add_panel(fig, 331, to_palette(labels[0]), 'labels')
            add_panel(fig, 337, chw_to_hwc(images_orig[0]), 'datas')
            add_panel(fig, 334, to_palette(torch.max(pred2, dim=1)[1][0]),
                      'predicts')
            add_panel(fig, 332, to_palette(labels_tar[0]), 'tar_labels')
            add_panel(fig, 338, chw_to_hwc(images_tar_orig[0]), 'tar_datas')
            add_panel(fig, 335,
                      to_palette(torch.max(pred_target2, dim=1)[1][0]),
                      'tar_predicts')
            print(policies[0], 'p1', rand_p1[0], 'p2', rand_p2[0],
                  'random_dir', random_dir[0])
            add_panel(fig, 333, to_palette(pseudo_label2[0]), 'psuedo_labels')
            add_panel(fig, 339, chw_to_hwc(images_aug_orig[0]), 'aug_datas')
            add_panel(fig, 336,
                      to_palette(torch.max(pred_target_aug2, dim=1)[1][0]),
                      'aug_predicts')
            writer['train'].add_figure('image/',
                                       fig,
                                       global_step=i_iter,
                                       close=True)

        # ==== periodic validation on the target val split
        if i_iter % 500 == 0:
            # Eval mode + no_grad: avoid building autograd graphs, and avoid
            # train-mode layer behaviour (e.g. running-stat updates) on val
            # data.
            model.eval()
            test_hist = np.zeros((num_cls, num_cls))
            test_loss1 = []
            test_loss2 = []
            with torch.no_grad():
                for test_images, test_labels in test_loader:
                    test_images_orig = test_images
                    test_images = transform_batch(
                        test_images, normalize_transform).cuda(args.gpu)

                    test_pred1, test_pred2 = model(test_images)
                    test_pred1 = interp_test(test_pred1)
                    test_pred2 = interp_test(test_pred2)

                    test_loss1.append(
                        loss_calc(test_pred1, test_labels, args.gpu).item())
                    test_loss2.append(
                        loss_calc(test_pred2, test_labels, args.gpu).item())

                    test_hist += fast_hist(
                        test_labels.cpu().numpy().flatten().astype(int),
                        torch.argmax(test_pred2, dim=1).cpu().numpy()
                        .flatten().astype(int),
                        num_cls)
            model.train()

            print('test')
            fig = plt.figure(figsize=(15, 15))
            add_panel(fig, 311, to_palette(test_labels[-1]), 'labels')
            add_panel(fig, 313, chw_to_hwc(test_images_orig[-1]), 'datas')
            add_panel(fig, 312,
                      to_palette(torch.max(test_pred2, dim=1)[1][-1]),
                      'predicts')
            writer['test'].add_figure('test_image/',
                                      fig,
                                      global_step=i_iter,
                                      close=True)

            avg_test_acc, mIoU = report_stats(test_hist)
            avg_test_loss1 = np.mean(test_loss1)
            avg_test_loss2 = np.mean(test_loss2)
            print('avg_test_loss1 :', avg_test_loss1)
            print('avg_test_loss2 :', avg_test_loss2)
            print('avg_test_acc   :', avg_test_acc)
            writer['test'].add_scalar('log1/loss_seg', avg_test_loss1, i_iter)
            writer['test'].add_scalar('log2/loss_seg', avg_test_loss2, i_iter)
            writer['test'].add_scalar('log/acc', avg_test_acc, i_iter)
            writer['test'].add_scalar('log/mIoU', mIoU, i_iter)

            # "Best" is judged on validation mIoU only, not on the
            # source-domain training mIoU computed every 10 iterations.
            if max_mIoU < mIoU:
                max_mIoU = mIoU
                save_checkpoints('best_iter')

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_checkpoints(str(args.num_steps_stop))
            break
def main():
    """Create the model and start the training."""

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True
    gpu = args.gpu

    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=input_size,
                                              scale=args.random_scale,
                                              mirror=args.random_mirror,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=False,
        mirror=args.random_mirror,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    # Implemented by Bongjoon Hyun
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)

    model_D1.train()
    model_D1.cuda(args.gpu)

    model_D2.train()
    model_D2.cuda(args.gpu)
    #

    # Implemented by Bongjoon Hyun
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()
    #

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        # Implemented by Bongjoon Hyun
        bce_loss = torch.nn.MSELoss()
        #

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):
        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):
            # Implemented by Bongjoon Hyun
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source

            _, batch = next(trainloader_iter)
            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu)
            loss_seg2 = loss_calc(pred2, labels, args.gpu)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()

            loss_seg_value1 += loss_seg1.data.cpu().numpy() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().numpy() / args.iter_size

            _, batch = next(targetloader_iter)
            images, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D1_out = model_D1(F.softmax(pred_target1))
            D2_out = model_D2(F.softmax(pred_target2))

            labels_source1 = Variable(
                torch.FloatTensor(
                    D1_out.data.size()).fill_(source_label)).cuda(args.gpu)
            labels_source2 = Variable(
                torch.FloatTensor(
                    D2_out.data.size()).fill_(source_label)).cuda(args.gpu)

            loss_adv_target1 = bce_loss(D1_out, labels_source1)
            loss_adv_target2 = bce_loss(D2_out, labels_source2)

            loss = args.lambda_adv_target1 * loss_adv_target1 + \
                   args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size

            loss.backward()

            loss_adv_target_value1 += loss_adv_target1.data.cpu().numpy(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().numpy(
            ) / args.iter_size

            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D1_out = model_D1(F.softmax(pred1))
            D2_out = model_D2(F.softmax(pred2))

            labels_source1 = Variable(
                torch.FloatTensor(
                    D1_out.data.size()).fill_(source_label)).cuda(args.gpu)
            labels_source2 = Variable(
                torch.FloatTensor(
                    D2_out.data.size()).fill_(source_label)).cuda(args.gpu)

            loss_D1 = bce_loss(D1_out, labels_source1) / args.iter_size / 2
            loss_D2 = bce_loss(D2_out, labels_source2) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()

            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D1_out = model_D1(F.softmax(pred_target1))
            D2_out = model_D2(F.softmax(pred_target2))

            labels_target1 = Variable(
                torch.FloatTensor(
                    D1_out.data.size()).fill_(target_label)).cuda(args.gpu)
            labels_target2 = Variable(
                torch.FloatTensor(
                    D2_out.data.size()).fill_(target_label)).cuda(args.gpu)

            loss_D1 = bce_loss(D1_out, labels_target1) / args.iter_size / 2
            loss_D2 = bce_loss(D2_out, labels_target2) / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().numpy()
            loss_D_value2 += loss_D2.data.cpu().numpy()
            #

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))

        if i_iter >= args.num_steps_stop - 1:
            print
            'save model ...'
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print
            'taking snapshot ...'
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D2.pth'))
def main():
    """Create the model and start the adversarial domain-adaptation training.

    All configuration is read from the module-level ``args`` namespace; nothing
    is passed in. The function:

    1. builds a ``DeeplabMulti`` segmentation network initialised from
       ``args.restore_from`` (URL or local path), plus two ``FCDiscriminator``
       models -- ``model_D1`` judges the first output head, ``model_D2`` the
       second;
    2. builds source / target training loaders and a validation loader from
       ``MulitviewSegLoader``; the training loaders are wrapped by
       ``data_loader_cycle`` (presumably an endless iterator -- ``next`` is
       called ``num_steps * iter_size`` times);
    3. for each of ``args.num_steps`` iterations, accumulates gradients over
       ``args.iter_size`` sub-batches of (a) source segmentation loss,
       (b) adversarial loss on target predictions and (c) discriminator loss,
       then steps all three optimizers;
    4. validates every ``args.val_steps`` iterations, prints and logs losses to
       TensorBoard every 10 iterations, saves checkpoints into
       ``args.snapshot_dir`` every ``args.save_pred_every`` iterations, and
       saves a final checkpoint and stops at ``args.num_steps_stop``.

    Raises:
        ValueError: if ``args.model`` is not ``'DeepLab'`` or ``args.gan`` is
            neither ``'Vanilla'`` nor ``'LS'``.
    """
    # Validate the configuration before creating any directories or loaders.
    if args.model != 'DeepLab':
        raise ValueError('Unsupported model: {!r}'.format(args.model))
    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        # Least-squares GAN objective.
        bce_loss = torch.nn.MSELoss()
    else:
        raise ValueError('Unsupported GAN objective: {!r}'.format(args.gan))

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network and copy matching pretrained weights into it.
    model = DeeplabMulti(num_classes=args.num_classes)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    new_params = model.state_dict().copy()
    for key in saved_state_dict:
        # Checkpoint keys carry one extra leading component, e.g.
        # 'Scale.layer5.conv2d_list.3.weight'; it is stripped to match the
        # model's own parameter names.
        i_parts = key.split('.')
        # NOTE(review): with this condition nothing at all is copied when
        # num_classes == 19, and layer5 parameters are always skipped
        # otherwise -- confirm that is intended.
        if args.num_classes != 19 and i_parts[1] != 'layer5':
            new_params['.'.join(i_parts[1:])] = saved_state_dict[key]
    model.load_state_dict(new_params)

    start_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
    writer_dir = os.path.join("./logs/", args.name, start_time)
    writer = tensorboard.SummaryWriter(writer_dir)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # One discriminator per segmentation output head.
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    for disc in (model_D1, model_D2):
        disc.train()
        disc.cuda(args.gpu)

    os.makedirs(args.snapshot_dir, exist_ok=True)

    trainloader = data.DataLoader(
        MulitviewSegLoader(num_classes=args.num_classes,
                           root=args.data_dir,
                           number_views=2,
                           view_idx=1,
                           img_mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    trainloader_iter = iter(data_loader_cycle(trainloader))

    targetloader = data.DataLoader(
        MulitviewSegLoader(root=args.data_dir_target,
                           num_classes=args.num_classes,
                           number_views=1,
                           view_idx=0,
                           img_mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    targetloader_iter = iter(data_loader_cycle(targetloader))

    # align_corners=False is nn.Upsample's default; it is spelled out because
    # older PyTorch releases warn when bilinear mode omits it.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=False)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=False)

    def mdl_val_func(x):
        """Validation forward: second head's output, upsampled to target size."""
        return interp_target(model(x)[1])

    val_loader = data.DataLoader(
        MulitviewSegLoader(root=args.data_dir_val,
                           number_views=1,
                           view_idx=0,
                           num_classes=args.num_classes,
                           img_mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True)

    criterion = CrossEntropyLoss2d().cuda(args.gpu)
    valhelper = ValHelper(gpu=args.gpu,
                          model=mdl_val_func,
                          val_loader=val_loader,
                          loss=criterion,
                          writer=writer)

    # model.optim_parameters(args) handles per-model learning-rate settings.
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    def set_discriminators_trainable(flag):
        """Set requires_grad on every parameter of both discriminators."""
        for disc in (model_D1, model_D2):
            for param in disc.parameters():
                param.requires_grad = flag

    def discriminate(seg1, seg2):
        """Run each discriminator on the per-pixel softmax of its head."""
        # Inputs are 4-D (N, C, H, W) after nn.Upsample, so classes are dim 1.
        return (model_D1(F.softmax(seg1, dim=1)),
                model_D2(F.softmax(seg2, dim=1)))

    def loss_vs_label(d_out, label):
        """Adversarial loss of ``d_out`` against a constant domain label."""
        # Same dtype/device as d_out and no grad, like the explicit
        # FloatTensor(...).fill_(label).cuda(gpu) construction it replaces.
        return bce_loss(d_out, torch.full_like(d_out, label))

    def save_snapshot(tag):
        """Save model, model_D1 and model_D2 as GTA5_<tag>{,_D1,_D2}.pth."""
        for suffix, module in (('', model), ('_D1', model_D1),
                               ('_D2', model_D2)):
            torch.save(
                module.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(tag) + suffix + '.pth'))

    # Labels for adversarial training.
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):
        loss_seg_value1 = loss_adv_target_value1 = loss_D_value1 = 0
        loss_seg_value2 = loss_adv_target_value2 = loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        # Gradients are accumulated over iter_size sub-batches and applied
        # once below, so every loss is divided by iter_size before backward().
        for _ in range(args.iter_size):
            # ---- train G: freeze D so it accumulates no gradients here.
            set_discriminators_trainable(False)

            # Supervised segmentation loss on the source domain.
            images, labels, *_ = next(trainloader_iter)
            images = images.cuda(args.gpu)

            pred1, pred2 = model(images)
            pred1 = interp(pred1)
            pred2 = interp(pred2)

            loss_seg1 = loss_calc(pred1, labels, args.gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, args.gpu, criterion)
            loss = (loss_seg2 + args.lambda_seg * loss_seg1) / args.iter_size
            loss.backward()
            loss_seg_value1 += loss_seg1.item() / args.iter_size
            loss_seg_value2 += loss_seg2.item() / args.iter_size

            # Adversarial loss on the target domain: the segmenter is pushed
            # to make D assign the *source* label to target predictions.
            images, *_ = next(targetloader_iter)
            images = images.cuda(args.gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1, D_out2 = discriminate(pred_target1, pred_target2)
            loss_adv_target1 = loss_vs_label(D_out1, source_label)
            loss_adv_target2 = loss_vs_label(D_out2, source_label)

            loss = (args.lambda_adv_target1 * loss_adv_target1 +
                    args.lambda_adv_target2 * loss_adv_target2)
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.item() / args.iter_size

            # ---- train D: re-enable its gradients; predictions are
            # detached so nothing flows back into the segmenter.
            set_discriminators_trainable(True)

            # Source predictions should be classified as source.
            D_out1, D_out2 = discriminate(pred1.detach(), pred2.detach())
            loss_D1 = loss_vs_label(D_out1, source_label) / args.iter_size / 2
            loss_D2 = loss_vs_label(D_out2, source_label) / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

            # Target predictions should be classified as target.
            D_out1, D_out2 = discriminate(pred_target1.detach(),
                                          pred_target2.detach())
            loss_D1 = loss_vs_label(D_out1, target_label) / args.iter_size / 2
            loss_D2 = loss_vs_label(D_out2, target_label) / args.iter_size / 2
            loss_D1.backward()
            loss_D2.backward()
            loss_D_value1 += loss_D1.item()
            loss_D_value2 += loss_D2.item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        if i_iter % args.val_steps == 0 and i_iter:
            model.eval()
            log = valhelper.valid_epoch(i_iter)
            print('log: {}'.format(log))
            model.train()

        if i_iter % 10 == 0 and i_iter:
            print('exp = {}'.format(args.snapshot_dir))
            print(
                'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
                .format(i_iter, args.num_steps, loss_seg_value1,
                        loss_seg_value2, loss_adv_target_value1,
                        loss_adv_target_value2, loss_D_value1, loss_D_value2))
            for tag, value in (
                    ('loss_seg_value1', loss_seg_value1),
                    ('loss_seg_value2', loss_seg_value2),
                    ('loss_adv_target_value1', loss_adv_target_value1),
                    ('loss_adv_target_value2', loss_adv_target_value2),
                    ('loss_D_value1', loss_D_value1),
                    ('loss_D_value2', loss_D_value2),
            ):
                writer.add_scalar('train/' + tag, value, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            save_snapshot(args.num_steps_stop)
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_snapshot(i_iter)