Example #1
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.start_iter = args.start_iter
        self.num_steps = args.num_steps
        self.num_classes = args.num_classes
        self.preheat = self.num_steps/20  # damping instead of early stopping
        self.source_label = 0
        self.target_label = 1
        self.best_miou = 0
        # TODO: CHANGE LOSS for LSGAN
        self.bce_loss = torch.nn.BCEWithLogitsLoss()
        #self.bce_loss = torch.nn.MSELoss() #LSGAN
        self.weighted_bce_loss = WeightedBCEWithLogitsLoss()
        self.aux_acc = AverageMeter()
        self.save_path = args.prediction_dir  # dir to save class mIoU when validating model
        self.losses = {'seg': list(), 'seg_t': list(), 'adv': list(), 'weight': list(), 'ds': list(), 'dt': list(), 'aux': list()}
        self.rotations = [0, 90, 180, 270]  # rotation angles for the self-supervised auxiliary task

        cudnn.enabled = True
        #cudnn.benchmark = True

        # set up models
        if args.model.name == 'DeepLab':
            self.model = Res_Deeplab(num_classes=args.num_classes, restore_from=args.model.restore_from)
            self.optimizer = optim.SGD(self.model.optim_parameters(args.model.optimizer), lr=args.model.optimizer.lr, momentum=args.model.optimizer.momentum, weight_decay=args.model.optimizer.weight_decay)
        elif args.model.name == 'ErfNet':
            self.model = ERFNet(args.num_classes)  # TODO: add ImageNet pre-training and a second classifier head
            self.optimizer = optim.SGD(self.model.optim_parameters(args.model.optimizer), lr=args.model.optimizer.lr, momentum=args.model.optimizer.momentum, weight_decay=args.model.optimizer.weight_decay)

        if args.method.adversarial:
            self.model_D = discriminator(name=args.discriminator.name, num_classes=args.num_classes, restore_from=args.discriminator.restore_from)
            self.optimizer_D = optim.Adam(self.model_D.parameters(), lr=args.discriminator.optimizer.lr, betas=(0.9, 0.99))
        if args.method.self:
            self.model_A = auxiliary(name=args.auxiliary.name, input_dim=args.auxiliary.classes, aux_classes=args.auxiliary.aux_classes, restore_from=args.auxiliary.restore_from)
            self.optimizer_A = optim.Adam(self.model_A.parameters(), lr=args.auxiliary.optimizer.lr, betas=(0.9, 0.99))
            self.aux_loss = nn.CrossEntropyLoss()
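Example #1's snippet is the __init__ of a trainer class (the class header is cut off by the listing). It instantiates a custom WeightedBCEWithLogitsLoss next to the standard BCEWithLogitsLoss; that class is never shown in these examples, so the following is a minimal sketch with a matching call signature (logits, target, weight_map, Epsilon, Lambda_local), assuming the weight map simply modulates the per-pixel BCE terms. The project's actual implementation may differ.

import torch.nn as nn
import torch.nn.functional as F


class WeightedBCEWithLogitsLoss(nn.Module):
    """Minimal sketch (assumed behaviour, not the project's exact code):
    per-pixel BCE with logits blended with a re-weighted term, matching calls
    of the form loss(D_out, target, weight_map, Epsilon, Lambda_local)."""

    def forward(self, logits, target, weight_map=None, epsilon=0.4, lambda_local=10.0):
        # Element-wise BCE with logits, reduced only at the end.
        loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        if weight_map is not None:
            # Keep a uniform term (epsilon) and add a term emphasised where the
            # two classifier heads disagree (weight_map is large there).
            loss = epsilon * loss + lambda_local * weight_map * loss
        return loss.mean()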
Example #2
def main(args):
    ## fix random_seed
    fixRandomSeed(1)

    ## cuda setting
    cudnn.benchmark = True
    cudnn.enabled = True
    device = torch.device('cuda:' + str(args.gpuid))
    torch.cuda.set_device(device)

    ## Logger setting
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    print('logs_dir=', args.logs_dir)
    print('args : ', args)

    ## get dataset & dataloader:
    dataset, source_num_classes, source_train_loader, \
    target_train_loader, query_loader, gallery_loader = get_data(args.data_dir, args.source,args.target,
                                                                 args.source_train_path, args.target_train_path,
                                                                 args.source_extension,args.target_extension,
                                                                 args.height, args.width,
                                                                 args.batch_size, args.re, args.workers)

    h, w = map(int, [args.height, args.width])
    input_size_source = (h, w)
    input_size_target = (h, w)

    # cudnn.enabled = True

    # Create Network
    # model = Res_Deeplab(num_classes=args.num_classes)
    model = Res_Deeplab(num_classes=source_num_classes)
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()

    ## copy pretrained weights into new_params; when training with 19 classes the classifier (layer5) weights are skipped
    for i in saved_state_dict:
        i_parts = i.split('.')
        if not args.num_classes == 19 or not i_parts[1] == 'layer5':
            new_params['.'.join(i_parts[1:])] = saved_state_dict[i]

    if args.restore_from[:4] == './mo':
        model.load_state_dict(new_params)
    else:
        model.load_state_dict(saved_state_dict)

    ## set mode = train and moves the params of model to GPU
    model.train()
    model.cuda(args.gpu)

    # cudnn.benchmark = True

    # Init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    # =============================================================================
    #    #for retrain
    #    saved_state_dict_D = torch.load(RESTORE_FROM_D)
    #    model_D.load_state_dict(saved_state_dict_D)
    # =============================================================================

    model_D.train()
    model_D.cuda(args.gpu)

    # if not os.path.exists(args.snapshot_dir):
    #     os.makedirs(args.snapshot_dir)

    if args.source == 'GTA5':
        trainloader = data.DataLoader(GTA5DataSet(
            args.data_dir,
            args.data_list,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_source,
            scale=True,
            mirror=True,
            mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
    else:
        trainloader = data.DataLoader(SYNTHIADataSet(
            args.data_dir,
            args.data_list,
            max_iters=args.num_steps * args.iter_size * args.batch_size,
            crop_size=input_size_source,
            scale=True,
            mirror=True,
            mean=IMG_MEAN),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=True,
        mirror=True,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    targetloader_iter = enumerate(targetloader)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    weighted_bce_loss = WeightedBCEWithLogitsLoss()

    # nn.Upsample expects size=(H, W); input_size_source/target are stored as (h, w)
    interp_source = nn.Upsample(size=(input_size_source[0],
                                      input_size_source[1]),
                                mode='bilinear',
                                align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[0],
                                      input_size_target[1]),
                                mode='bilinear',
                                align_corners=True)

    # Labels for Adversarial Training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        damping = (1 - i_iter / NUM_STEPS)

        # ======================================================================================
        # train G
        # ======================================================================================

        # Remove Grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # Train with Source
        _, batch = next(trainloader_iter)
        images_s, labels_s, _, _, _ = batch
        images_s = Variable(images_s).cuda(args.gpu)
        pred_source1, pred_source2 = model(images_s)
        pred_source1 = interp_source(pred_source1)
        pred_source2 = interp_source(pred_source2)

        # Segmentation Loss
        loss_seg = (loss_calc(pred_source1, labels_s, args.gpu) +
                    loss_calc(pred_source2, labels_s, args.gpu))
        loss_seg.backward()

        # Train with Target
        _, batch = next(targetloader_iter)
        images_t, _, _, _ = batch
        images_t = Variable(images_t).cuda(args.gpu)

        pred_target1, pred_target2 = model(images_t)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        weight_map = weightmap(F.softmax(pred_target1, dim=1),
                               F.softmax(pred_target2, dim=1))

        D_out = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
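        # Before PREHEAT_STEPS the generator is trained with a plain BCE
        # adversarial loss; afterwards the weighted variant concentrates the
        # loss on pixels where weight_map (classifier disagreement) is large.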
        if (i_iter > PREHEAT_STEPS):
            loss_adv = weighted_bce_loss(
                D_out,
                Variable(
                    torch.FloatTensor(
                        D_out.data.size()).fill_(source_label)).cuda(args.gpu),
                weight_map, Epsilon, Lambda_local)
        else:
            loss_adv = bce_loss(
                D_out,
                Variable(
                    torch.FloatTensor(
                        D_out.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_adv = loss_adv * Lambda_adv * damping
        loss_adv.backward()

        # Weight Discrepancy Loss
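        # Cosine similarity between the flattened weights of the two classifier
        # heads (layer5 / layer6); minimising it (the +1 keeps the loss
        # non-negative) pushes the heads to stay diverse. Note that W5/W6 stay
        # None unless args.model == 'ResNet'.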
        W5 = None
        W6 = None
        if args.model == 'ResNet':

            for (w5, w6) in zip(model.layer5.parameters(),
                                model.layer6.parameters()):
                if W5 is None and W6 is None:
                    W5 = w5.view(-1)
                    W6 = w6.view(-1)
                else:
                    W5 = torch.cat((W5, w5.view(-1)), 0)
                    W6 = torch.cat((W6, w6.view(-1)), 0)

        loss_weight = (torch.matmul(W5, W6) /
                       (torch.norm(W5) * torch.norm(W6)) + 1
                       )  # +1 is for a positive loss
        loss_weight = loss_weight * Lambda_weight * damping * 2
        loss_weight.backward()

        # ======================================================================================
        # train D
        # ======================================================================================

        # Bring back Grads in D
        for param in model_D.parameters():
            param.requires_grad = True

        # Train with Source
        pred_source1 = pred_source1.detach()
        pred_source2 = pred_source2.detach()

        D_out_s = interp_source(
            model_D(F.softmax(pred_source1 + pred_source2, dim=1)))

        loss_D_s = bce_loss(
            D_out_s,
            Variable(
                torch.FloatTensor(
                    D_out_s.data.size()).fill_(source_label)).cuda(args.gpu))

        loss_D_s.backward()

        # Train with Target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()
        weight_map = weight_map.detach()

        D_out_t = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
        if (i_iter > PREHEAT_STEPS):
            loss_D_t = weighted_bce_loss(
                D_out_t,
                Variable(
                    torch.FloatTensor(
                        D_out_t.data.size()).fill_(target_label)).cuda(
                            args.gpu), weight_map, Epsilon, Lambda_local)
        else:
            loss_D_t = bce_loss(
                D_out_t,
                Variable(
                    torch.FloatTensor(
                        D_out_t.data.size()).fill_(target_label)).cuda(
                            args.gpu))

        loss_D_t.backward()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:6d}/{1:6d}, loss_seg = {2:.4f} loss_adv = {3:.4f}, loss_weight = {4:.4f}, loss_D_s = {5:.4f} loss_D_t = {6:.4f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv, loss_weight,
                    loss_D_s, loss_D_t))

        f_loss = open(osp.join(args.snapshot_dir, 'loss.txt'), 'a')
        f_loss.write('{0:.4f} {1:.4f} {2:.4f} {3:.4f} {4:.4f}\n'.format(
            loss_seg, loss_adv, loss_weight, loss_D_s, loss_D_t))
        f_loss.close()

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

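Example #2 (and Example #4 below) calls a weightmap(...) helper on the two softmax outputs but never defines it. A minimal sketch, assuming it measures per-pixel disagreement as a cosine distance over the class dimension; the original repository may define it differently.

import torch


def weightmap(pred1, pred2, eps=1e-7):
    """Sketch (assumed): per-pixel cosine distance between two softmax maps of
    shape (N, C, H, W). Returns (N, 1, H, W); values are near 0 where the two
    classifier heads agree and grow where they disagree."""
    num = torch.sum(pred1 * pred2, dim=1, keepdim=True)
    den = (pred1.norm(p=2, dim=1, keepdim=True) *
           pred2.norm(p=2, dim=1, keepdim=True)).clamp(min=eps)
    return 1.0 - num / den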
Example #3
def main():
    """Create the model and start the training."""
    save_dir = osp.join(args.snapshot_dir, args.method)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    writer = SummaryWriter(save_dir)
    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)
    cudnn.enabled = True
    # Create network
    if args.backbone == 'resnet':
        model = Deeplab_Res101(num_classes=args.num_classes)
    if args.resume:
        print("Resuming from ==>>", args.resume)
        state_dict = torch.load(args.resume)
        model.load_state_dict(state_dict)
    else:
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            # print i_parts
            if not args.num_classes == 19 or not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                # print i_parts
        model.load_state_dict(new_params)
    model.train()
    model.cuda()
    cudnn.benchmark = True

    # init D
    model_D = EightwayASADiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda()

    print(model_D)
    pprint(vars(args))
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.batch_size,
                                              img_size=input_size),
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)
    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.batch_size,
        img_size=input_size_target,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)
    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()
    bce_loss = torch.nn.BCEWithLogitsLoss()
    weight_bce_loss = WeightedBCEWithLogitsLoss()
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    source_label = 0
    target_label = 1
    start = timeit.default_timer()
    loss_seg_value = 0
    loss_adv_target_value = 0
    loss_D_value = 0
    for i_iter in range(args.num_steps):
        damping = (1 - i_iter / args.num_steps)
        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # train G
        # don't accumulate grads in D
        for param in model_D.parameters():
            param.requires_grad = False
        # train with source
        _, batch = next(trainloader_iter)
        src_img, labels, _, _ = batch
        src_img = Variable(src_img).cuda()
        pred = model(src_img)
        pred = interp(pred)
        loss_seg = loss_calc(pred, labels)
        loss_seg.backward()
        loss_seg_value += loss_seg.item()

        # train with target
        _, batch = next(targetloader_iter)
        tar_img, _, _, _ = batch
        tar_img = Variable(tar_img).cuda()
        pred_target = model(tar_img)
        pred_target = interp_target(pred_target)
        D_out = model_D(F.softmax(pred_target, dim=1))
        loss_adv_target = bce_loss(
            D_out,
            torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda())
        loss_adv = loss_adv_target * args.lambda_adv_target1 * damping
        loss_adv.backward()
        loss_adv_target_value += loss_adv_target.item()
        # train D
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True
        # train with source
        pred = pred.detach()
        D_out = model_D(F.softmax(pred, dim=1))
        loss_D1 = bce_loss(
            D_out,
            torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda())
        loss_D1 = loss_D1 / 2
        loss_D1.backward()
        loss_D_value += loss_D1.item()
        # train with target
        pred_target = pred_target.detach()
        D_out1 = model_D(F.softmax(pred_target, dim=1))
        loss_D1 = bce_loss(
            D_out1,
            torch.FloatTensor(D_out1.data.size()).fill_(target_label).cuda())
        loss_D1 = loss_D1 / 2
        loss_D1.backward()
        loss_D_value += loss_D1.item()
        optimizer.step()
        optimizer_D.step()
        current = timeit.default_timer()

        if i_iter % 50 == 0:
            print(
                'iter = {0:6d}/{1:6d}, loss_seg1 = {2:.3f}  loss_adv1 = {3:.3f}, loss_D1 = {4:.3f} ({5:.3f}/iter)'
                .format(i_iter, args.num_steps, loss_seg_value / 50,
                        loss_adv_target_value / 50, loss_D_value / 50,
                        (current - start) / (i_iter + 1)))
            writer.add_scalar('learning_rate', lr, i_iter)
            writer.add_scalars(
                "Loss", {
                    "Seg": loss_seg_value / 50,
                    "Adv": loss_adv_target_value / 50,
                    "Disc": loss_D_value / 50
                }, i_iter)  # log the same 50-iteration averages that are printed
            loss_seg_value = 0
            loss_adv_target_value = 0
            loss_D_value = 0

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(save_dir, 'GTA5KLASA_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(save_dir, 'GTA5KLASA_' + str(i_iter) + '_D.pth'))

        if (i_iter + 1) >= args.num_steps_stop:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(save_dir,
                         'GTA5KLASA_' + str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(save_dir,
                         'GTA5KLASA_' + str(args.num_steps_stop) + '_D.pth'))
Example #4
def main():
    """Create the model and start the training."""

    cudnn.enabled = True
    cudnn.benchmark = True

    device = torch.device("cuda" if not args.cpu else "cpu")

    random.seed(args.random_seed)

    snapshot_dir = os.path.join(args.snapshot_dir, args.experiment)
    log_dir = os.path.join(args.log_dir, args.experiment)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(snapshot_dir, exist_ok=True)

    log_file = os.path.join(log_dir, 'log.txt')

    init_log(log_file, args)

    # =============================================================================
    # INIT G
    # =============================================================================
    if MODEL == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes,
                            restore_from=args.restore_from)
    model.train()
    model.to(device)

    # =============================================================================
    # INIT D
    # =============================================================================

    model_D = FCDiscriminator(num_classes=args.num_classes)

    # saved_state_dict_D = torch.load(RESTORE_FROM_D) #for retrain
    # model_D.load_state_dict(saved_state_dict_D)

    model_D.train()
    model_D.to(device)

    # DataLoaders
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=args.input_size_source,
                                              scale=True,
                                              mirror=True,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=args.input_size_target,
        scale=True,
        mirror=True,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # Optimizers
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # Losses
    bce_loss = torch.nn.BCEWithLogitsLoss()
    weighted_bce_loss = WeightedBCEWithLogitsLoss()

    interp_source = nn.Upsample(size=(args.input_size_source[1],
                                      args.input_size_source[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_target = nn.Upsample(size=(args.input_size_target[1],
                                      args.input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # Labels for Adversarial Training
    source_label = 0
    target_label = 1

    # ======================================================================================
    # Start training
    # ======================================================================================
    print('###########   TRAINING STARTED  ############')
    start = time.time()

    for i_iter in range(args.start_from_iter, args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        damping = (1 - (i_iter) / NUM_STEPS)

        # ======================================================================================
        # train G
        # ======================================================================================

        # Remove Grads in D
        for param in model_D.parameters():
            param.requires_grad = False

        # Train with Source
        _, batch = next(trainloader_iter)
        images_s, labels_s, _, _ = batch
        images_s = images_s.to(device)
        pred_source1, pred_source2 = model(images_s)

        pred_source1 = interp_source(pred_source1)
        pred_source2 = interp_source(pred_source2)

        # Segmentation Loss
        loss_seg = (loss_calc(pred_source1, labels_s, device) +
                    loss_calc(pred_source2, labels_s, device))
        loss_seg.backward()

        # Train with Target
        _, batch = next(targetloader_iter)
        images_t, _, _ = batch
        images_t = images_t.to(device)

        pred_target1, pred_target2 = model(images_t)

        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        weight_map = weightmap(F.softmax(pred_target1, dim=1),
                               F.softmax(pred_target2, dim=1))

        D_out = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
        if i_iter > PREHEAT_STEPS:
            loss_adv = weighted_bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device),
                weight_map, Epsilon, Lambda_local)
        else:
            loss_adv = bce_loss(
                D_out,
                torch.FloatTensor(
                    D_out.data.size()).fill_(source_label).to(device))

        loss_adv = loss_adv * Lambda_adv * damping
        loss_adv.backward()

        # Weight Discrepancy Loss
        W5 = None
        W6 = None
        if args.model == 'ResNet':

            for (w5, w6) in zip(model.layer5.parameters(),
                                model.layer6.parameters()):
                if W5 is None and W6 is None:
                    W5 = w5.view(-1)
                    W6 = w6.view(-1)
                else:
                    W5 = torch.cat((W5, w5.view(-1)), 0)
                    W6 = torch.cat((W6, w6.view(-1)), 0)

        loss_weight = (torch.matmul(W5, W6) /
                       (torch.norm(W5) * torch.norm(W6)) + 1
                       )  # +1 is for a positive loss
        loss_weight = loss_weight * Lambda_weight * damping * 2
        loss_weight.backward()

        # ======================================================================================
        # train D
        # ======================================================================================

        # Bring back Grads in D
        for param in model_D.parameters():
            param.requires_grad = True

        # Train with Source
        pred_source1 = pred_source1.detach()
        pred_source2 = pred_source2.detach()

        D_out_s = interp_source(
            model_D(F.softmax(pred_source1 + pred_source2, dim=1)))

        loss_D_s = bce_loss(
            D_out_s,
            torch.FloatTensor(
                D_out_s.data.size()).fill_(source_label).to(device))

        loss_D_s.backward()

        # Train with Target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()
        weight_map = weight_map.detach()

        D_out_t = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
        if (i_iter > PREHEAT_STEPS):
            loss_D_t = weighted_bce_loss(
                D_out_t,
                torch.FloatTensor(
                    D_out_t.data.size()).fill_(target_label).to(device),
                weight_map, Epsilon, Lambda_local)
        else:
            loss_D_t = bce_loss(
                D_out_t,
                torch.FloatTensor(
                    D_out_t.data.size()).fill_(target_label).to(device))

        loss_D_t.backward()

        optimizer.step()
        optimizer_D.step()

        if (i_iter) % 10 == 0:
            log_message(
                'Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f} loss_adv = {3:.4f}, loss_weight = {4:.4f}, loss_D_s = {5:.4f} loss_D_t = {6:.4f}'
                .format(i_iter, args.num_steps, loss_seg, loss_adv,
                        loss_weight, loss_D_s, loss_D_t), log_file)

        if (i_iter % args.save_pred_every == 0
                and i_iter != 0) or i_iter == args.num_steps - 1:
            i_iter = i_iter if i_iter != args.num_steps - 1 else i_iter + 1  # label the final snapshot with num_steps
            print('saving weights...')
            torch.save(model.state_dict(),
                       osp.join(snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))

    end = time.time()
    log_message(
        'Total training time: {} days, {} hours, {} min, {} sec '.format(
            int((end - start) / 86400), int((end - start) / 3600 % 24),
            int((end - start) / 60 % 60), int((end - start) % 60)), log_file)
    print('### Experiment: ' + args.experiment + ' finished ###')
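Examples #2-#4 rely on a loss_calc helper for the segmentation loss. A minimal sketch, assuming the conventional 2D cross-entropy with the 'void' label 255 ignored; the device/GPU argument seen in some of the calls is accepted only for signature compatibility.

import torch.nn as nn


def loss_calc(pred, label, device=None):
    """Sketch (assumed): pixel-wise cross-entropy between logits of shape
    (N, C, H, W) and integer labels of shape (N, H, W), ignoring label 255.
    The device/gpu argument is unused here; labels are simply moved onto the
    prediction's device."""
    label = label.long().to(pred.device)
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    return criterion(pred, label)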