Esempio n. 1
0
def main():
    """Create the model and start the evaluation process.

    Evaluates the snapshots ``./snapshots/GTA2Cityscapes/GTA5_<k>.pth`` for
    k = 2000, 4000, ..., 98000. For every snapshot each image of the
    Cityscapes list ``args.data_list`` is segmented and two files are written
    to ``./result/GTA2Cityscapes_<k>``: the argmax label map (``<name>``) and
    a colourised version (``<stem>_color.png``).
    """
    # Argument parsing, the test loader and the upsampling layer do not depend
    # on the snapshot being evaluated, so they are built once, not per snapshot.
    args = get_arguments()
    gpu0 = args.gpu

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    for i in range(1, 50):
        step = i * 2000
        model_path = './snapshots/GTA2Cityscapes/GTA5_{0:d}.pth'.format(step)
        save_path = './result/GTA2Cityscapes_{0:d}'.format(step)

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(save_path, exist_ok=True)

        model = Res_Deeplab(num_classes=args.num_classes)

        # Restore on the CPU first; the model is moved to gpu0 right after, so
        # the checkpoint never has to be materialised on the GPU it was saved from.
        saved_state_dict = torch.load(model_path, map_location='cpu')
        model.load_state_dict(saved_state_dict)

        model.eval()
        model.cuda(gpu0)

        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processed' % index)
                image, _, _, name = batch
                output1, output2 = model(image.cuda(gpu0))

                # Sum the two model outputs, upsample to full resolution and
                # take the per-pixel argmax as the predicted class id.
                output = interp(output1 + output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save('%s/%s' % (save_path, name))

                output_col.save('%s/%s_color.png' %
                                (save_path, name.split('.')[0]))

        print(save_path)
Esempio n. 2
0
def main():
    """Create the model and start the evaluation process.

    Loads the checkpoint ``args.restore_from`` into a ``Res_Deeplab`` model,
    runs it over the Cityscapes list ``args.data_list`` and writes, for every
    image, the argmax label map (``<save>/<name>``) and a colourised version
    (``<save>/<stem>_color.png``). Runs on CUDA unless ``args.cpu`` is set.
    """
    args = get_arguments()

    os.makedirs(args.save, exist_ok=True)

    device = torch.device("cuda" if not args.cpu else "cpu")

    model = Res_Deeplab(num_classes=args.num_classes)

    # map_location makes a checkpoint saved from a GPU run loadable with --cpu
    # (or on a machine without CUDA); without it torch.load tries to restore
    # the tensors onto the CUDA device they were saved from.
    saved_state_dict = torch.load(args.restore_from, map_location=device)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.to(device)

    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=True)

    print('### STARTING EVALUATING ###')
    print('total to process: %d' % len(testloader))
    with torch.no_grad():
        for index, batch in enumerate(testloader):
            if index % 100 == 0:
                print('%d processed' % index)
            image, _, name = batch
            output1, output2 = model(image.to(device))

            # Sum the two model outputs, upsample to full resolution and take
            # the per-pixel argmax as the predicted class id.
            output = interp(output1 + output2).cpu().data[0].numpy()
            output = output.transpose(1, 2, 0)
            output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

            output_col = colorize_mask(output)
            output = Image.fromarray(output)

            name = name[0].split('/')[-1]
            output.save('%s/%s' % (args.save, name))

            output_col.save('%s/%s_color.png' % (args.save, name.split('.')[0]))

    print('### EVALUATING FINISHED ###')
Esempio n. 3
0
def main(args):
    """Train Res_Deeplab with adversarial + weight-discrepancy adaptation.

    Each iteration trains the segmentation network G on a labelled source
    batch (GTA5 or SYNTHIA) and an unlabelled Cityscapes target batch, then
    trains the discriminator ``model_D`` on the detached source/target
    predictions. Loss values are appended to ``<snapshot_dir>/loss.txt`` and
    snapshots of both networks are saved every ``args.save_pred_every``
    iterations and when ``args.num_steps_stop`` is reached.

    Args:
        args: parsed command-line namespace (data paths, image size,
            optimiser hyper-parameters, GPU ids, snapshot settings).
    """
    import os  # needed only to create the snapshot directory

    ## fix random_seed
    fixRandomSeed(1)

    ## cuda setting
    cudnn.benchmark = True
    cudnn.enabled = True
    device = torch.device('cuda:' + str(args.gpuid))
    torch.cuda.set_device(device)

    ## Logger setting
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    print('logs_dir=', args.logs_dir)
    print('args : ', args)

    ## get dataset & dataloader
    # Only the number of source classes is used below; the other values
    # returned by get_data are not consumed in this function.
    _, source_num_classes, _, _, _, _ = get_data(
        args.data_dir, args.source, args.target,
        args.source_train_path, args.target_train_path,
        args.source_extension, args.target_extension,
        args.height, args.width,
        args.batch_size, args.re, args.workers)

    h, w = map(int, [args.height, args.width])
    input_size_source = (h, w)
    input_size_target = (h, w)

    # Create Network
    model = Res_Deeplab(num_classes=source_num_classes)
    if args.restore_from.startswith('http'):
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    new_params = model.state_dict().copy()

    ## adapt new_params's layers / classes to saved_state_dict:
    # drop the first dotted component of every checkpoint key and, when
    # training 19 classes, skip the 'layer5' entries.
    for key in saved_state_dict:
        key_parts = key.split('.')
        if args.num_classes != 19 or key_parts[1] != 'layer5':
            new_params['.'.join(key_parts[1:])] = saved_state_dict[key]

    # Checkpoints whose path starts with './mo' are loaded through the
    # remapped keys; any other checkpoint is loaded as-is.
    if args.restore_from.startswith('./mo'):
        model.load_state_dict(new_params)
    else:
        model.load_state_dict(saved_state_dict)

    ## set mode = train and move the params of model to GPU
    model.train()
    model.cuda(args.gpu)

    # Init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cuda(args.gpu)

    # loss.txt and the snapshots are written here, so it must exist.
    os.makedirs(args.snapshot_dir, exist_ok=True)

    source_dataset_cls = GTA5DataSet if args.source == 'GTA5' else SYNTHIADataSet
    trainloader = data.DataLoader(source_dataset_cls(
        args.data_dir,
        args.data_list,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_source,
        scale=True,
        mirror=True,
        mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=input_size_target,
        scale=True,
        mirror=True,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    weighted_bce_loss = WeightedBCEWithLogitsLoss()

    interp_source = nn.Upsample(size=(input_size_source[1],
                                      input_size_source[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # Labels for Adversarial Training
    source_label = 0
    target_label = 1

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # Linearly decaying weight for the adversarial and weight losses.
        # NOTE(review): based on the module constant NUM_STEPS, not args.num_steps.
        damping = (1 - i_iter / NUM_STEPS)

        # ======================================================================================
        # train G
        # ======================================================================================

        # Freeze D so the adversarial loss only produces gradients for G.
        for param in model_D.parameters():
            param.requires_grad = False

        # Train with Source
        _, batch = next(trainloader_iter)
        images_s, labels_s, _, _, _ = batch
        images_s = images_s.cuda(args.gpu)
        pred_source1, pred_source2 = model(images_s)
        pred_source1 = interp_source(pred_source1)
        pred_source2 = interp_source(pred_source2)

        # Segmentation Loss
        loss_seg = (loss_calc(pred_source1, labels_s, args.gpu) +
                    loss_calc(pred_source2, labels_s, args.gpu))
        loss_seg.backward()

        # Train with Target
        _, batch = next(targetloader_iter)
        images_t, _, _, _ = batch
        images_t = images_t.cuda(args.gpu)

        pred_target1, pred_target2 = model(images_t)
        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        weight_map = weightmap(F.softmax(pred_target1, dim=1),
                               F.softmax(pred_target2, dim=1))

        D_out = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss: target predictions are scored against the
        # source label; after PREHEAT_STEPS the loss is weighted by weight_map.
        # full_like allocates the label map directly with D_out's dtype/device.
        source_like = torch.full_like(D_out, source_label)
        if i_iter > PREHEAT_STEPS:
            loss_adv = weighted_bce_loss(D_out, source_like, weight_map,
                                         Epsilon, Lambda_local)
        else:
            loss_adv = bce_loss(D_out, source_like)

        loss_adv = loss_adv * Lambda_adv * damping
        loss_adv.backward()

        # Weight Discrepancy Loss: cosine similarity of the flattened layer5 and
        # layer6 parameters (+1 keeps it non-negative). The model built above is
        # always a Res_Deeplab, so both layers exist; the vectors are
        # concatenated once instead of re-copying a growing tensor per parameter.
        w5_parts, w6_parts = [], []
        for w5, w6 in zip(model.layer5.parameters(),
                          model.layer6.parameters()):
            w5_parts.append(w5.view(-1))
            w6_parts.append(w6.view(-1))
        W5 = torch.cat(w5_parts, 0)
        W6 = torch.cat(w6_parts, 0)

        loss_weight = (torch.matmul(W5, W6) /
                       (torch.norm(W5) * torch.norm(W6)) + 1)
        loss_weight = loss_weight * Lambda_weight * damping * 2
        loss_weight.backward()

        # ======================================================================================
        # train D
        # ======================================================================================

        # Bring back Grads in D
        for param in model_D.parameters():
            param.requires_grad = True

        # Train with Source (detached, so only D receives gradients)
        pred_source1 = pred_source1.detach()
        pred_source2 = pred_source2.detach()

        D_out_s = interp_source(
            model_D(F.softmax(pred_source1 + pred_source2, dim=1)))

        loss_D_s = bce_loss(D_out_s, torch.full_like(D_out_s, source_label))
        loss_D_s.backward()

        # Train with Target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()
        weight_map = weight_map.detach()

        D_out_t = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
        target_like = torch.full_like(D_out_t, target_label)
        if i_iter > PREHEAT_STEPS:
            loss_D_t = weighted_bce_loss(D_out_t, target_like, weight_map,
                                         Epsilon, Lambda_local)
        else:
            loss_D_t = bce_loss(D_out_t, target_like)

        loss_D_t.backward()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:6d}/{1:6d}, loss_seg = {2:.4f} loss_adv = {3:.4f}, loss_weight = {4:.4f}, loss_D_s = {5:.4f} loss_D_t = {6:.4f}'
            .format(i_iter, args.num_steps, loss_seg, loss_adv, loss_weight,
                    loss_D_s, loss_D_t))

        with open(osp.join(args.snapshot_dir, 'loss.txt'), 'a') as f_loss:
            f_loss.write('{0:.4f} {1:.4f} {2:.4f} {3:.4f} {4:.4f}\n'.format(
                loss_seg, loss_adv, loss_weight, loss_D_s, loss_D_t))

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'GTA5_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '_D.pth'))
Esempio n. 4
0
def main():
    """Create the model and start the training."""

    cudnn.enabled = True
    cudnn.benchmark = True

    device = torch.device("cuda" if not args.cpu else "cpu")

    snapshot_dir = os.path.join(args.snapshot_dir, args.experiment)
    os.makedirs(snapshot_dir, exist_ok=True)

    log_file = os.path.join(args.log_dir, '%s.txt' % args.experiment)
    init_log(log_file, args)

    # =============================================================================
    # INIT G
    # =============================================================================
    if MODEL == 'ResNet':
        model = Res_Deeplab(num_classes=args.num_classes,
                            restore_from=args.restore_from)
    model.train()
    model.to(device)

    # DataLoaders
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=args.input_size_source,
                                              scale=True,
                                              mirror=True,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    #   trainloader = data.DataLoader(cityscapesDataSetLabel(args.data_dir_target, './dataset/cityscapes_list/info.json', args.data_list_target,args.data_list_label_target,
    #                                                   max_iters=args.num_steps * args.iter_size * args.batch_size, crop_size=args.input_size_target,
    #                                                   mean=IMG_MEAN, set=args.set), batch_size=args.batch_size, shuffle=True,num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # Optimizers
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    interp = nn.Upsample(size=(args.input_size_source[1],
                               args.input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    #interp = nn.Upsample(size=(args.input_size_target[1], args.input_size_target[0]), mode='bilinear', align_corners=True)

    # ======================================================================================
    # Start training
    # ======================================================================================
    log_message('###########   TRAINING STARTED  ############', log_file)
    start = time.time()

    for i_iter in range(args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        # ======================================================================================
        # train G
        # ======================================================================================

        # Train with Source
        _, batch = next(trainloader_iter)
        images_s, labels_s = batch
        images_s = images_s.to(device)

        pred_source1, pred_source2 = model(images_s)

        pred_source1 = interp(pred_source1)
        pred_source2 = interp(pred_source2)
        # Segmentation Loss
        loss_seg = (loss_calc(pred_source1, labels_s, device) +
                    loss_calc(pred_source2, labels_s, device))
        loss_seg.backward()

        optimizer.step()

        if i_iter % 10 == 0:
            log_message(
                'Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f}'.format(
                    i_iter, args.num_steps - 1, loss_seg), log_file)

        if (i_iter % args.save_pred_every == 0
                and i_iter != 0) or i_iter == args.num_steps - 1:
            print('saving weights...')
            torch.save(model.state_dict(),
                       osp.join(snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))

    end = time.time()
    log_message(
        'Total training time: {} days, {} hours, {} min, {} sec '.format(
            int((end - start) / 86400), int((end - start) / 3600),
            int((end - start) / 60 % 60), int((end - start) % 60)), log_file)
    print('### Experiment: ' + args.experiment + ' Finished ###')
def main(args):
    """Create the model and start the evaluation process.

    Evaluates the snapshots ``GTA5_<k>.pth`` (k = save_step, 2*save_step, ...)
    stored in ``<restore_from>/<experiment>`` on the Cityscapes list
    ``args.data_list``. For every snapshot the argmax label maps and their
    colourised versions are written to ``<save>/<experiment>/<k>``.

    The number of snapshots is derived from the number of files in the model
    directory, halved when ``args.d`` is set and divided by three when
    ``args.a`` is set (presumably because other per-step files, e.g.
    discriminator weights, live alongside — TODO confirm).
    """
    save_dir = os.path.join(args.save, args.experiment)
    model_dir = os.path.join(args.restore_from, args.experiment)
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if not args.cpu else "cpu")
    start = time.time()

    n_files = len(os.listdir(model_dir))
    if args.d:
        n_files //= 2
    if args.a:
        n_files //= 3

    # The test loader and upsampling layer are identical for every snapshot,
    # so they are built once outside the snapshot loop.
    testloader = data.DataLoader(cityscapesDataSet(args.data_dir,
                                                   args.data_list,
                                                   crop_size=(1024, 512),
                                                   mean=IMG_MEAN,
                                                   scale=False,
                                                   mirror=False,
                                                   set=args.set),
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    interp = nn.Upsample(size=(1024, 2048),
                         mode='bilinear',
                         align_corners=True)

    for i in range(1, n_files + 1):
        step = i * args.save_step
        model_path = os.path.join(model_dir, 'GTA5_{0:d}.pth'.format(step))
        save_path = os.path.join(save_dir, '{0:d}'.format(step))
        os.makedirs(save_path, exist_ok=True)

        print('#### Evaluating model: ' + str(i) + '/' + str(n_files) +
              ' ####')

        model = Res_Deeplab(num_classes=args.num_classes)

        # map_location lets GPU-saved checkpoints be evaluated with --cpu.
        saved_state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(saved_state_dict)

        model.eval()
        model.to(device)

        with torch.no_grad():
            for index, batch in enumerate(testloader):
                if index % 100 == 0:
                    print('%d processed' % index)
                image, _, name = batch
                output1, output2 = model(image.to(device))

                # Sum the two model outputs, upsample to full resolution and
                # take the per-pixel argmax as the predicted class id.
                output = interp(output1 + output2).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.asarray(np.argmax(output, axis=2), dtype=np.uint8)

                output_col = colorize_mask(output)
                output = Image.fromarray(output)

                name = name[0].split('/')[-1]
                output.save('%s/%s' % (save_path, name))

                output_col.save('%s/%s_color.png' %
                                (save_path, name.split('.')[0]))

        print(save_path)

    # Total whole minutes (not wrapped at 60), so runs longer than an hour are
    # reported correctly.
    minutes, secs = divmod(int(time.time() - start), 60)
    print('Total time: {} min, {} sec '.format(minutes, secs))
Esempio n. 6
0
def main():
    """Create the model and start the training.

    Adversarial GTA5 -> Cityscapes training. Each iteration trains
    Res_Deeplab (G) on a labelled GTA5 batch and an unlabelled Cityscapes
    batch, then trains the FCDiscriminator (D) on the detached predictions.
    Progress is logged every 10 iterations to ``<log_dir>/<experiment>/log.txt``;
    G and D snapshots are saved into ``<snapshot_dir>/<experiment>`` every
    ``args.save_pred_every`` iterations and at the last iteration (named after
    ``args.num_steps``). Reads the module-level ``args``.

    Raises:
        ValueError: if the module constant ``MODEL`` is not ``'ResNet'``, the
            only architecture this function builds.
    """
    cudnn.enabled = True
    cudnn.benchmark = True

    device = torch.device("cuda" if not args.cpu else "cpu")

    random.seed(args.random_seed)

    snapshot_dir = os.path.join(args.snapshot_dir, args.experiment)
    log_dir = os.path.join(args.log_dir, args.experiment)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(snapshot_dir, exist_ok=True)

    log_file = os.path.join(log_dir, 'log.txt')

    init_log(log_file, args)

    # =============================================================================
    # INIT G
    # =============================================================================
    # Fail clearly instead of hitting an unbound `model` further down.
    if MODEL != 'ResNet':
        raise ValueError('Unsupported MODEL: %r' % (MODEL,))
    model = Res_Deeplab(num_classes=args.num_classes,
                        restore_from=args.restore_from)
    model.train()
    model.to(device)

    # =============================================================================
    # INIT D
    # =============================================================================
    model_D = FCDiscriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.to(device)

    # DataLoaders
    trainloader = data.DataLoader(GTA5DataSet(args.data_dir,
                                              args.data_list,
                                              max_iters=args.num_steps *
                                              args.iter_size * args.batch_size,
                                              crop_size=args.input_size_source,
                                              scale=True,
                                              mirror=True,
                                              mean=IMG_MEAN),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    trainloader_iter = enumerate(trainloader)

    targetloader = data.DataLoader(cityscapesDataSet(
        args.data_dir_target,
        args.data_list_target,
        max_iters=args.num_steps * args.iter_size * args.batch_size,
        crop_size=args.input_size_target,
        scale=True,
        mirror=True,
        mean=IMG_MEAN,
        set=args.set),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    targetloader_iter = enumerate(targetloader)

    # Optimizers
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # Losses
    bce_loss = torch.nn.BCEWithLogitsLoss()
    weighted_bce_loss = WeightedBCEWithLogitsLoss()

    interp_source = nn.Upsample(size=(args.input_size_source[1],
                                      args.input_size_source[0]),
                                mode='bilinear',
                                align_corners=True)
    interp_target = nn.Upsample(size=(args.input_size_target[1],
                                      args.input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # Labels for Adversarial Training
    source_label = 0
    target_label = 1

    # ======================================================================================
    # Start training
    # ======================================================================================
    print('###########   TRAINING STARTED  ############')
    start = time.time()

    for i_iter in range(args.start_from_iter, args.num_steps):

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # Linearly decaying weight for the adversarial and weight losses.
        # NOTE(review): based on the module constant NUM_STEPS, not args.num_steps.
        damping = (1 - i_iter / NUM_STEPS)

        # ======================================================================================
        # train G
        # ======================================================================================

        # Freeze D so the adversarial loss only produces gradients for G.
        for param in model_D.parameters():
            param.requires_grad = False

        # Train with Source
        _, batch = next(trainloader_iter)
        images_s, labels_s, _, _ = batch
        images_s = images_s.to(device)
        pred_source1, pred_source2 = model(images_s)

        pred_source1 = interp_source(pred_source1)
        pred_source2 = interp_source(pred_source2)

        # Segmentation Loss
        loss_seg = (loss_calc(pred_source1, labels_s, device) +
                    loss_calc(pred_source2, labels_s, device))
        loss_seg.backward()

        # Train with Target
        _, batch = next(targetloader_iter)
        images_t, _, _ = batch
        images_t = images_t.to(device)

        pred_target1, pred_target2 = model(images_t)

        pred_target1 = interp_target(pred_target1)
        pred_target2 = interp_target(pred_target2)

        weight_map = weightmap(F.softmax(pred_target1, dim=1),
                               F.softmax(pred_target2, dim=1))

        D_out = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss: target predictions are scored against the
        # source label; after PREHEAT_STEPS the loss is weighted by weight_map.
        # full_like allocates the label map directly with D_out's dtype/device.
        source_like = torch.full_like(D_out, source_label)
        if i_iter > PREHEAT_STEPS:
            loss_adv = weighted_bce_loss(D_out, source_like, weight_map,
                                         Epsilon, Lambda_local)
        else:
            loss_adv = bce_loss(D_out, source_like)

        loss_adv = loss_adv * Lambda_adv * damping
        loss_adv.backward()

        # Weight Discrepancy Loss: cosine similarity of the flattened layer5 and
        # layer6 parameters (+1 keeps it non-negative). The model built above is
        # always a Res_Deeplab, so both layers exist; the vectors are
        # concatenated once instead of re-copying a growing tensor per parameter.
        w5_parts, w6_parts = [], []
        for w5, w6 in zip(model.layer5.parameters(),
                          model.layer6.parameters()):
            w5_parts.append(w5.view(-1))
            w6_parts.append(w6.view(-1))
        W5 = torch.cat(w5_parts, 0)
        W6 = torch.cat(w6_parts, 0)

        loss_weight = (torch.matmul(W5, W6) /
                       (torch.norm(W5) * torch.norm(W6)) + 1)
        loss_weight = loss_weight * Lambda_weight * damping * 2
        loss_weight.backward()

        # ======================================================================================
        # train D
        # ======================================================================================

        # Bring back Grads in D
        for param in model_D.parameters():
            param.requires_grad = True

        # Train with Source (detached, so only D receives gradients)
        pred_source1 = pred_source1.detach()
        pred_source2 = pred_source2.detach()

        D_out_s = interp_source(
            model_D(F.softmax(pred_source1 + pred_source2, dim=1)))

        loss_D_s = bce_loss(D_out_s, torch.full_like(D_out_s, source_label))
        loss_D_s.backward()

        # Train with Target
        pred_target1 = pred_target1.detach()
        pred_target2 = pred_target2.detach()
        weight_map = weight_map.detach()

        D_out_t = interp_target(
            model_D(F.softmax(pred_target1 + pred_target2, dim=1)))

        # Adaptive Adversarial Loss
        target_like = torch.full_like(D_out_t, target_label)
        if i_iter > PREHEAT_STEPS:
            loss_D_t = weighted_bce_loss(D_out_t, target_like, weight_map,
                                         Epsilon, Lambda_local)
        else:
            loss_D_t = bce_loss(D_out_t, target_like)

        loss_D_t.backward()

        optimizer.step()
        optimizer_D.step()

        if i_iter % 10 == 0:
            log_message(
                'Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f} loss_adv = {3:.4f}, loss_weight = {4:.4f}, loss_D_s = {5:.4f} loss_D_t = {6:.4f}'
                .format(i_iter, args.num_steps, loss_seg, loss_adv,
                        loss_weight, loss_D_s, loss_D_t), log_file)

        is_last_iter = i_iter == args.num_steps - 1
        if (i_iter % args.save_pred_every == 0 and i_iter != 0) or is_last_iter:
            # The final snapshot is named after the total step count
            # (args.num_steps); the loop variable itself is left untouched.
            save_iter = i_iter + 1 if is_last_iter else i_iter
            print('saving weights...')
            torch.save(model.state_dict(),
                       osp.join(snapshot_dir, 'GTA5_' + str(save_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'GTA5_' + str(save_iter) + '_D.pth'))

    # Split the run time so each field is a remainder (hours modulo 24, etc.).
    days, rem = divmod(int(time.time() - start), 86400)
    hours, rem = divmod(rem, 3600)
    minutes, secs = divmod(rem, 60)
    log_message(
        'Total training time: {} days, {} hours, {} min, {} sec '.format(
            days, hours, minutes, secs), log_file)
    print('### Experiment: ' + args.experiment + ' finished ###')