def ranking_target_w_discrim(model_seg, target_loader, output_path, cfg):

    train_lst_ordered = osp.join(output_path,'train_ent_full.txt')

    writer = SummaryWriter(log_dir=cfg.TEST.SNAPSHOT_DIR)

    num_classes = cfg.NUM_CLASSES
    device = cfg.GPU_ID    
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    targetloader_iter = enumerate(target_loader)

    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear',
                                align_corners=True)

    dis_dictionary = {}
    l1loss = torch.nn.L1Loss()
    with torch.no_grad():
        for i in tqdm(range(len(target_loader))):
        # for i in tqdm(range(30)):
            _, batch = targetloader_iter.__next__()
            images, image_aug, _, _, name, _ = batch
            pred_trg_main, _ = model_seg(images.cuda(device))

            # edge = CannyFilter(images)


            # pred_trg_main_aug, _ = model_seg(image_aug.cuda(device))

            pred_trg_main = interp_target(pred_trg_main)
            # pred_trg_main_aug = F.softmax(interp_target(pred_trg_main_aug),dim=1)

            # pred_trg_max, _ = torch.max(pred_trg_main,dim=1)
            # pred_trg_main_aug_ = pred_trg_main * pred_trg_main_aug
            # diff = l1loss(pred_trg_main, pred_trg_main_aug_)
            # diff = torch.mean(diff)

            # pred_trg_entropy = model_dis(F.softmax(pred_trg_main))
            pred_trg_entropy = prob_2_entropy(F.softmax(pred_trg_main))
            # pred_trg_entropy = abs(float(pred_trg_entropy.mean()) - 0.5)
            # import pdb
            # pdb.set_trace()

            pred_trg_entropy = float(pred_trg_entropy.mean())

            dis_dictionary[name[0]] = pred_trg_entropy

            # if i % 2 == 0:
            #     draw_in_tensorboard(writer, images, image_aug, i, pred_trg_main, pred_trg_main_aug_, 'S')
            # import pdb
            # pdb.set_trace()
        dis_dictionary = sorted(dis_dictionary.items(),key=f2)
        with open(train_lst_ordered, 'w') as f:
            for i in range(len(dis_dictionary)):
                f.write("%s\t%s\n" % (dis_dictionary[i][0], dis_dictionary[i][1]))
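
For reference, prob_2_entropy is the helper these examples use to turn softmax probabilities into per-pixel weighted self-information maps before ranking or feeding the discriminator. A minimal sketch following the ADVENT-style definition (the actual helper lives in the repo's utils):

import numpy as np
import torch


def prob_2_entropy(prob):
    """Convert a (N, C, H, W) probability map into weighted self-information,
    -p * log2(p), normalized by log2(C) so values stay roughly in [0, 1]."""
    n, c, h, w = prob.size()
    return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)
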
Example #2
def train_advent(model, trainloader, targetloader, cfg):
    ''' UDA training with advent
    '''
    # Create the model and start the training.
    # pdb.set_trace()
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    # restore_from = cfg.TRAIN.RESTORE_FROM_aux
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_aux, restore_from, device)


    # seg maps, i.e. output-level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # restore_from = cfg.TRAIN.RESTORE_FROM_main
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_main, restore_from, device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]), mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP)):

        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet; don't accumulate grads in the discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source 
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        # debug:
        # labels=labels.numpy()
        # from matplotlib import pyplot as plt
        # import numpy as np
        # plt.figure(1), plt.imshow(labels[0]), plt.ion(), plt.colorbar(), plt.show()
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        # pdb.set_trace()
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
                + cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        current_losses = {'loss_seg_src_aux': loss_seg_src_aux,
                          'loss_seg_src_main': loss_seg_src_main,
                          'loss_adv_trg_aux': loss_adv_trg_aux,
                          'loss_adv_trg_main': loss_adv_trg_main,
                          'loss_d_aux': loss_d_aux,
                          'loss_d_main': loss_d_main}
        print_losses(current_losses, i_iter)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(), snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(), snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(), snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main, num_classes, 'T')
                draw_in_tensorboard(writer, images_source, i_iter, pred_src_main, num_classes, 'S')
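
The adversarial and discriminator losses above go through bce_loss, which is assumed to compare the discriminator output against a constant domain label (0 for source, 1 for target). A minimal sketch of that behaviour:

import torch
import torch.nn.functional as F


def bce_loss(y_pred, y_label):
    """Binary cross-entropy with logits against a constant domain label
    broadcast to the shape of the discriminator output."""
    y_truth = torch.full_like(y_pred, float(y_label))
    return F.binary_cross_entropy_with_logits(y_pred, y_truth)
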
Example #3
def main(args):
    # load configuration file
    device = cfg.GPU_ID
    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)

    if not os.path.exists(save_dir % (args.FDA_mode, args.round)):
        os.mkdir(save_dir % (args.FDA_mode, args.round))
    # ----------------------------------------------------------------#
    args.LB = str(args.MBT)  # set args.LB = 'MBT'
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))

    # here, replace by restoring three different models
    if args.round == 0:  # first round of SSL
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}'

    elif args.round > 0:  # when SSL round is higher than 0

        # SOURCE and TARGET are no longer GTA5 and Cityscapes, but the easy and hard splits
        cfg.SOURCE = 'CityscapesEasy'
        cfg.TARGET = 'CityscapesHard'
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}_THRESH_{str(thresholding)}_ROUND_{args.round - 1}'
    else:
        raise KeyError()

    ##########################################################################################################################################
    # ----------------------------------------------------------------#
    cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)

    # load model with parameters trained from Inter-domain adaptation
    model_gen = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TEST.MULTI_LEVEL)

    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0],
                            f'model_{args.best_iter}.pth')

    print("Loading the generator:", restore_from)

    load_checkpoint_for_evaluation(model_gen, restore_from, device)

    # load data
    target_dataset = CityscapesDataSet(args=args,
                                       root=cfg.DATA_DIRECTORY_TARGET,
                                       list_path=cfg.DATA_LIST_TARGET,
                                       set=cfg.TRAIN.SET_TARGET,
                                       info_path=cfg.TRAIN.INFO_TARGET,
                                       max_iters=None,
                                       crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                       mean=cfg.TRAIN.IMG_MEAN)

    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=None)

    target_loader_iter = enumerate(target_loader)

    # upsampling layer
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # ---------------------------------------------------------------------------------------------------------------#

    # step 1. entropy-ranking: split the target dataset into easy and hard cases.

    entropy_list = []
    for index in tqdm(range(len(target_loader))):
        _, batch = target_loader_iter.__next__()
        image, _, _, name = batch

        # normalize the image before it is fed into the trained model
        B, C, H, W = image.shape
        mean_image = SRC_IMG_MEAN.repeat(B, 1, H, W)

        if args.FDA_mode == 'on':
            image -= mean_image

        elif args.FDA_mode == 'off':
            # normalization was already done in the dataset class (GTA5, Cityscapes) when args.FDA_mode == 'off'
            pass

        else:
            raise KeyError()

        with torch.no_grad():
            _, pred_trg_main = model_gen(
                image.cuda(device))  # shape(pred_trg_main) = (1, 19, 65, 129)
            pred_trg_main = interp_target(
                pred_trg_main)  # shape(pred_trg_main) = (1, 19, 512, 1024)
            if args.normalize:
                normalizor = (11 -
                              len(find_rare_class(pred_trg_main))) / 11.0 + 0.5
            else:
                normalizor = 1
            pred_trg_entropy = prob_2_entropy(F.softmax(pred_trg_main))
            entropy_list.append(
                (name[0], pred_trg_entropy.mean().item() * normalizor))
            # colorize_save(pred_trg_main, name[0], args.FDA_mode)

    # split the entropy_list into easy and hard splits
    _, easy_split = cluster_subdomain(entropy_list, args, thresholding)

    # ---------------------------------------------------------------------------------------------------------------#

    # step 2. apply thresholding (either the top 66% per class or a confidence cap of 0.9) to the easy-split target images and save the pseudo-labels.

    predicted_label = np.zeros(
        (len(easy_split), 512,
         1024))  # (512, 1024) is the size of target output
    predicted_prob = np.zeros((len(easy_split), 512, 1024))
    image_name = []
    idx = 0

    target_loader_iter = enumerate(target_loader)

    for index in tqdm(range(len(target_loader))):
        _, batch = target_loader_iter.__next__()
        image, _, _, name = batch

        if name[0] not in easy_split:  # only process images that belong to the easy split
            continue

        # normalize the image before it is fed into the trained model
        B, C, H, W = image.shape
        mean_image = SRC_IMG_MEAN.repeat(B, 1, H, W)

        if args.FDA_mode == 'on':
            image -= mean_image

        elif args.FDA_mode == 'off':
            # normalization was already done in the dataset class (GTA5, Cityscapes) when args.FDA_mode == 'off'
            pass

        else:
            raise KeyError()

        with torch.no_grad():
            _, pred_trg_main = model_gen(
                image.cuda(device))  # shape(pred_trg_main) = (1, 19, 65, 129)
            pred_trg_main = F.softmax(interp_target(pred_trg_main), dim=1).cpu(
            ).data[0].numpy()  # shape(pred_trg_main) = (1, 19, 512, 1024)
            pred_trg_main = pred_trg_main.transpose(
                1, 2, 0)  # shape(pred_trg_main) = (512, 1024, 19)
            label, prob = np.argmax(pred_trg_main,
                                    axis=2), np.max(pred_trg_main, axis=2)
            predicted_label[idx] = label
            predicted_prob[idx] = prob
            image_name.append(name[0])
            idx += 1

    assert len(easy_split) == len(
        image_name)  # check whether all images in easy-split are processed

    # compute the threshold for each label
    thres = []
    for i in range(cfg.NUM_CLASSES):
        x = predicted_prob[predicted_label == i]
        if len(x) == 0:
            thres.append(0)
            continue
        x = np.sort(x)
        thres.append(
            x[int(np.round(len(x) * 0.66))]
        )  # thres[i] holds the per-class confidence threshold for label i
    print(thres)
    thres = np.array(thres)
    thres[thres > 0.9] = 0.9

    print(thres)
    colorize_save_with_thresholding(easy_split, thres, predicted_label,
                                    predicted_prob, image_name, args)
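
cluster_subdomain is not shown in this listing; the call above assumes a variant that sorts images by their entropy score and returns the hard and easy name lists (args and thresholding are accepted here only for signature compatibility). A hypothetical sketch of that ranking step; the real helper also writes its splits to repo-specific paths:

def cluster_subdomain(entropy_list, args, thresholding, lambda1=0.67):
    """Hypothetical sketch: sort (name, entropy) pairs ascending and take the
    lowest-entropy fraction lambda1 as the easy split, the rest as hard."""
    ranked = sorted(entropy_list, key=lambda item: item[1])
    names = [name for name, _ in ranked]
    split_idx = int(len(names) * lambda1)
    easy_split, hard_split = names[:split_idx], names[split_idx:]

    with open('easy_split.txt', 'w') as f:
        f.write('\n'.join(easy_split))
    with open('hard_split.txt', 'w') as f:
        f.write('\n'.join(hard_split))

    return hard_split, easy_split
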
Example #4
def main(args):

    # load configuration file
    device = cfg.GPU_ID
    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)

    if not os.path.exists(save_dir % (args.FDA_mode)):
        os.mkdir(save_dir % (args.FDA_mode))
    # ----------------------------------------------------------------#
    args.LB = str(args.LB).replace('.', '_')
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))

    if args.round == 0:  # first round of SSL
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}'

    elif args.round > 0:  # when SSL round is higher than 0

        # SOURCE and TARGET are no longer GTA5 and Cityscapes, but the easy and hard splits
        cfg.SOURCE = 'CityscapesEasy'
        cfg.TARGET = 'CityscapesHard'
        cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}_THRESH_{str(thresholding)}_ROUND_{args.round - 1}'
    else:
        raise KeyError()

    #cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{args.FDA_mode}_LB_{args.LB}'
    # ----------------------------------------------------------------#
    cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)

    # load model with parameters trained from Inter-domain adaptation
    model_gen = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TEST.MULTI_LEVEL)

    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0],
                            f'model_{args.best_iter}.pth')

    print("Loading the generator:", restore_from)

    load_checkpoint_for_evaluation(model_gen, restore_from, device)

    # load data
    target_dataset = CityscapesDataSet(args=args,
                                       root=cfg.DATA_DIRECTORY_TARGET,
                                       list_path=cfg.DATA_LIST_TARGET,
                                       set=cfg.TRAIN.SET_TARGET,
                                       info_path=cfg.TRAIN.INFO_TARGET,
                                       max_iters=None,
                                       crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                       mean=cfg.TRAIN.IMG_MEAN)

    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=None)

    target_loader_iter = enumerate(target_loader)

    # upsampling layer
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    entropy_list = []
    for index in tqdm(range(len(target_loader))):
        _, batch = target_loader_iter.__next__()
        image, _, _, name = batch

        # ----------------------------------------------------------------#
        """
        normalize the image before fed into the trained model
        """
        B, C, H, W = image.shape
        mean_image = SRC_IMG_MEAN.repeat(B, 1, H, W)

        if args.FDA_mode == 'on':
            image -= mean_image

        elif args.FDA_mode == 'off':
            # normalization was already done in the dataset class (GTA5, Cityscapes) when args.FDA_mode == 'off'
            pass

        else:
            raise KeyError()
        # ----------------------------------------------------------------#

        with torch.no_grad():
            _, pred_trg_main = model_gen(image.cuda(device))
            pred_trg_main = interp_target(pred_trg_main)
            if args.normalize:
                normalizor = (11 -
                              len(find_rare_class(pred_trg_main))) / 11.0 + 0.5
            else:
                normalizor = 1
            pred_trg_entropy = prob_2_entropy(F.softmax(pred_trg_main))
            entropy_list.append(
                (name[0], pred_trg_entropy.mean().item() * normalizor))
            colorize_save(pred_trg_main, name[0], args)

    # split the entropy_list into easy and hard splits
    cluster_subdomain(entropy_list, args, thresholding)
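
The normalizor above assumes find_rare_class returns the set of rare Cityscapes train-ids that actually show up in the argmax prediction; a sketch under that assumption (RARE_CLASS_IDS is a placeholder, the authoritative id list is defined in the original repo):

import numpy as np

# placeholder set of rare Cityscapes train-ids; the authoritative list lives in the repo
RARE_CLASS_IDS = {3, 4, 5, 6, 7, 9, 12, 14, 15, 16, 17, 18}


def find_rare_class(pred_trg_main):
    """Return the rare class ids present in the argmax of a (1, C, H, W) prediction."""
    pred = pred_trg_main.cpu().data[0].numpy()            # (C, H, W)
    label = np.argmax(pred.transpose(1, 2, 0), axis=2)    # (H, W) train-ids
    return set(np.unique(label).tolist()) & RARE_CLASS_IDS
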
Example #5
def main(args):

    # load configuration file
    device = cfg.GPU_ID
    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)

    if not os.path.exists('./color_masks'):
        os.mkdir('./color_masks')

    cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}'
    cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)

    # load model with parameters trained from Inter-domain adaptation
    model_gen = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TEST.MULTI_LEVEL)

    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0],
                            f'model_{args.best_iter}.pth')

    print("Loading the generator:", restore_from)

    load_checkpoint_for_evaluation(model_gen, restore_from, device)

    # load data
    target_dataset = CityscapesDataSet(root=cfg.DATA_DIRECTORY_TARGET,
                                       list_path=cfg.DATA_LIST_TARGET,
                                       set=cfg.TRAIN.SET_TARGET,
                                       info_path=cfg.TRAIN.INFO_TARGET,
                                       max_iters=None,
                                       crop_size=cfg.TRAIN.INPUT_SIZE_TARGET,
                                       mean=cfg.TRAIN.IMG_MEAN)

    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=None)

    target_loader_iter = enumerate(target_loader)

    # upsampling layer
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    entropy_list = []
    for index in tqdm(range(len(target_loader))):
        _, batch = target_loader_iter.__next__()
        image, _, _, name = batch
        with torch.no_grad():
            _, pred_trg_main = model_gen(image.cuda(device))
            pred_trg_main = interp_target(pred_trg_main)
            if args.normalize:
                normalizor = (11 -
                              len(find_rare_class(pred_trg_main))) / 11.0 + 0.5
            else:
                normalizor = 1
            pred_trg_entropy = prob_2_entropy(F.softmax(pred_trg_main))
            entropy_list.append(
                (name[0], pred_trg_entropy.mean().item() * normalizor))
            colorize_save(pred_trg_main, name[0])

    # split the entropy_list into easy and hard splits
    cluster_subdomain(entropy_list, args.lambda1)
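
colorize_save is assumed to argmax the upsampled prediction, save the raw train-id mask, and save a palette-colorized copy under ./color_masks; a sketch with a truncated palette (PALETTE here is only the first three Cityscapes colors, the full 19-class palette is defined in the repo):

import numpy as np
from PIL import Image

# truncated Cityscapes palette (road, sidewalk, building); the full 19-class palette lives in the repo
PALETTE = [128, 64, 128, 244, 35, 232, 70, 70, 70]


def colorize_save(output_pt_tensor, name, out_dir='./color_masks'):
    """Save the argmax prediction as a grayscale mask and a palette-colorized PNG."""
    pred = output_pt_tensor.cpu().data[0].numpy().transpose(1, 2, 0)   # (H, W, C)
    mask = np.asarray(np.argmax(pred, axis=2), dtype=np.uint8)

    base = name.split('/')[-1]
    Image.fromarray(mask).save('%s/%s' % (out_dir, base))

    color = Image.fromarray(mask).convert('P')
    color.putpalette(PALETTE + [0] * (768 - len(PALETTE)))
    color.save('%s/%s_color.png' % (out_dir, base.split('.')[0]))
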
Example #6
def main(args):

    # load configuration file
    device = cfg.GPU_ID
    assert args.cfg is not None, 'Missing cfg file'
    cfg_from_file(args.cfg)

    if not os.path.exists('./output2'):
        os.mkdir('./output2')

    cfg.EXP_NAME = f'{cfg.SOURCE}2{cfg.TARGET}_{cfg.TRAIN.MODEL}_{cfg.TRAIN.DA_METHOD}_{cfg.NUM_CLASSES}class_18_01'
    cfg.TEST.SNAPSHOT_DIR[0] = osp.join(cfg.EXP_ROOT_SNAPSHOT, cfg.EXP_NAME)

    # load model with parameters trained from Inter-domain adaptation
    model_gen = get_deeplab_v2(num_classes=cfg.NUM_CLASSES,
                               multi_level=cfg.TEST.MULTI_LEVEL)

    restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0],
                            f'model_{args.best_iter}.pth')
    #restore_from = r'C:\semseg\IntraDA\ADVENT\experiments\snapshots\SimRunway2RealRunway_DeepLabv2_AdvEnt_5class_v2\model_120000.pth'
    print("Loading the generator:", restore_from)

    load_checkpoint_for_evaluation(model_gen, args.checkpoint, device)

    # load data
    target_dataset = RealRunwayDataSet(args.data_root,
                                       list_path=args.list_path + '/{}.txt',
                                       set=cfg.TEST.SET_TARGET,
                                       info_path=cfg.TEST.INFO_TARGET,
                                       crop_size=cfg.TEST.INPUT_SIZE_TARGET,
                                       mean=cfg.TEST.IMG_MEAN,
                                       labels_size=cfg.TEST.OUTPUT_SIZE_TARGET)

    target_loader = data.DataLoader(target_dataset,
                                    batch_size=cfg.TRAIN.BATCH_SIZE_TARGET,
                                    num_workers=cfg.NUM_WORKERS,
                                    shuffle=True,
                                    pin_memory=True,
                                    worker_init_fn=None)

    target_loader_iter = enumerate(target_loader)

    # upsampling layer
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    entropy_list = []
    for index in tqdm(range(len(target_loader))):
        _, batch = target_loader_iter.__next__()
        image, _, _, name = batch
        with torch.no_grad():
            _, pred_trg_main = model_gen(image.cuda(device))
            pred_trg_main = interp_target(pred_trg_main)
            if args.normalize:
                normalizor = (11 -
                              len(find_rare_class(pred_trg_main))) / 11.0 + 0.5
            else:
                normalizor = 1
            pred_trg_entropy = prob_2_entropy(F.softmax(pred_trg_main))
            entropy_list.append(
                (name[0], pred_trg_entropy.mean().item() * normalizor))
            colorize_save(pred_trg_main, name[0])
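
Each of these main() examples first restores the generator with load_checkpoint_for_evaluation; the helper is assumed to simply load the saved state dict and switch the model to eval mode on the target GPU, roughly:

import torch


def load_checkpoint_for_evaluation(model, checkpoint, device):
    """Restore a saved state dict and prepare the model for inference."""
    saved_state_dict = torch.load(checkpoint, map_location='cpu')
    model.load_state_dict(saved_state_dict)
    model.eval()
    model.cuda(device)
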
def train_preview(model, source_loader, target_loader, cfg, comet_exp):
    # UDA TRAINING
    ''' UDA training with advent
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)

    # seg maps, i.e. output-level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(),
                                 lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    times = deque([0], maxlen=100)
    model_times = deque([0], maxlen=100)

    source_loader_iter = enumerate(source_loader)
    target_loader_iter = enumerate(target_loader)

    cur_best_miou = -1
    cur_best_model = ''

    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):
        times.append(time())
        comet_exp.log_metric("i_iter", i_iter)

        comet_exp.log_metric("target_epoch", i_iter / len(target_loader))
        comet_exp.log_metric("source_epoch", i_iter / len(source_loader))
        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet; don't accumulate grads in the discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source
        try:
            _, batch_and_path = source_loader_iter.__next__()
        except StopIteration:
            source_loader_iter = enumerate(source_loader)
            _, batch_and_path = source_loader_iter.__next__()

        images_source, labels = batch_and_path['data']['x'], batch_and_path[
            'data']['m']
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main +
                cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator
        try:
            _, batch = target_loader_iter.__next__()
        except StopIteration:
            target_loader_iter = enumerate(target_loader)
            _, batch = target_loader_iter.__next__()

        images = batch['data']['x']
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main +
                cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        model_times.append(time() - times[-1])
        mod_times = np.mean(model_times)
        comet_exp.log_metric("model_time", mod_times)

        current_losses = {
            'loss_seg_src_aux': loss_seg_src_aux,
            'loss_seg_src_main': loss_seg_src_main,
            'loss_adv_trg_aux': loss_adv_trg_aux,
            'loss_adv_trg_main': loss_adv_trg_main,
            'loss_d_aux': loss_d_aux,
            'loss_d_main': loss_d_main
        }
        print_losses(current_losses, i_iter)
        current_losses_numDict = tesnorDict2numDict(current_losses)
        comet_exp.log_metrics(current_losses_numDict)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break

        if (i_iter % cfg.TRAIN.SAVE_IMAGE_PRED == 0 and i_iter != 0) or i_iter == cfg.TRAIN.EARLY_STOP:
            print("Inferring test images in iteration {}...".format(i_iter))
            hist = np.zeros((cfg.NUM_CLASSES, cfg.NUM_CLASSES))
            image, label = batch['data']['x'][0], batch['data']['m'][0]
            image = image[None, :, :, :]

            interp = nn.Upsample(size=(label.shape[1], label.shape[2]),
                                 mode='bilinear',
                                 align_corners=True)
            with torch.no_grad():
                pred_main = model(image.cuda(device))[1]
                output = interp(pred_main).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.argmax(output, axis=2)
            label0 = label.numpy()[0]
            hist += fast_hist(label0.flatten(), output.flatten(),
                              cfg.NUM_CLASSES)
            output = torch.tensor(output, dtype=torch.float32)
            output = output[None, :, :]
            output_RGB = output.repeat(3, 1, 1)

            if i_iter % 100 == 0:
                print('{:d} / {:d}: {:0.2f}'.format(
                    i_iter % len(target_loader), len(target_loader),
                    100 * np.nanmean(per_class_iu(hist))))
            inters_over_union_classes = per_class_iu(hist)
            computed_miou = round(
                np.nanmean(inters_over_union_classes) * 100, 2)
            if cur_best_miou < computed_miou:
                cur_best_miou = computed_miou
                cur_best_model = f'model_{i_iter}.pth'
            print('\tCurrent mIoU:', computed_miou)
            print('\tCurrent best model:', cur_best_model)
            print('\tCurrent best mIoU:', cur_best_miou)
            mious = {
                'Current mIoU': computed_miou,
                'Current best model': cur_best_model,
                'Current best mIoU': cur_best_miou
            }
            comet_exp.log_metrics(mious)
            image = image[0]  # change size from [1,x,y,z] to [x,y,z]
            save_images = []

            save_images.append(image)
            # Overlay mask:

            save_mask = (image - (image * label.repeat(3, 1, 1)) +
                         label.repeat(3, 1, 1))

            save_fake_mask = (image - (image * output_RGB) + output_RGB)
            save_images.append(save_mask)
            save_images.append(save_fake_mask)
            save_images.append(label.repeat(3, 1, 1))
            save_images.append(output_RGB)

            write_images(save_images,
                         i_iter,
                         comet_exp=comet_exp,
                         store_im=cfg.TEST.store_images)
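
The mIoU bookkeeping in train_preview relies on the usual fast_hist / per_class_iu helpers; a minimal sketch of the standard definitions:

import numpy as np


def fast_hist(a, b, n):
    """Accumulate an n x n confusion matrix from flattened label / prediction
    arrays, ignoring entries whose label falls outside [0, n)."""
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    """Per-class intersection-over-union from a confusion matrix."""
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
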
Example #8
def train_dada(model, trainloader, targetloader, cfg):
    """ UDA training with dada
    """
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # seg maps, i.e. output-level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(
        model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
        lr=cfg.TRAIN.LEARNING_RATE,
        momentum=cfg.TRAIN.MOMENTUM,
        weight_decay=cfg.TRAIN.WEIGHT_DECAY,
    )

    # discriminators' optimizers
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(
        size=(input_size_source[1], input_size_source[0]),
        mode="bilinear",
        align_corners=True,
    )
    interp_target = nn.Upsample(
        size=(input_size_target[1], input_size_target[0]),
        mode="bilinear",
        align_corners=True,
    )

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):
        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_main.zero_grad()
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet; don't accumulate grads in the discriminators
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source
        _, batch = trainloader_iter.__next__()
        images_source, labels, depth, _, _ = batch
        _, pred_src_main, pred_depth_src_main = model(
            images_source.cuda(device))
        pred_src_main = interp(pred_src_main)
        pred_depth_src_main = interp(pred_depth_src_main)
        loss_depth_src_main = loss_calc_depth(pred_depth_src_main, depth,
                                              device)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main +
                cfg.TRAIN.LAMBDA_DEPTH_MAIN * loss_depth_src_main)
        loss.backward()

        # adversarial training to fool the discriminator
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        _, pred_trg_main, pred_depth_trg_main = model(images.cuda(device))
        pred_trg_main = interp_target(pred_trg_main)
        pred_depth_trg_main = interp_target(pred_depth_trg_main)
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_trg_main)) * pred_depth_trg_main)
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        pred_src_main = pred_src_main.detach()
        pred_depth_src_main = pred_depth_src_main.detach()
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_src_main)) * pred_depth_src_main)
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main.backward()

        # train with target
        pred_trg_main = pred_trg_main.detach()
        pred_depth_trg_main = pred_depth_trg_main.detach()
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_trg_main)) * pred_depth_trg_main)
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main.backward()

        optimizer.step()
        optimizer_d_main.step()

        current_losses = {
            "loss_seg_src_main": loss_seg_src_main,
            "loss_depth_src_main": loss_depth_src_main,
            "loss_adv_trg_main": loss_adv_trg_main,
            "loss_d_main": loss_d_main,
        }
        print_losses(current_losses, i_iter)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print("taking snapshot ...")
            print("exp =", cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f"model_{i_iter}.pth")
            torch.save(d_main.state_dict(),
                       snapshot_dir / f"model_{i_iter}_D_main.pth")
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main,
                                    num_classes, "T")
                draw_in_tensorboard(writer, images_source, i_iter,
                                    pred_src_main, num_classes, "S")
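
adjust_learning_rate and adjust_learning_rate_discriminator follow a polynomial decay schedule in these training loops; a sketch assuming the usual cfg.TRAIN.MAX_ITERS and cfg.TRAIN.POWER config fields:

def lr_poly(base_lr, i_iter, max_iter, power):
    """Polynomial learning-rate decay."""
    return base_lr * ((1 - float(i_iter) / max_iter) ** power)


def _adjust_learning_rate(optimizer, i_iter, cfg, learning_rate):
    lr = lr_poly(learning_rate, i_iter, cfg.TRAIN.MAX_ITERS, cfg.TRAIN.POWER)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        # the second group typically holds the freshly initialized classifier at 10x LR
        optimizer.param_groups[1]['lr'] = lr * 10


def adjust_learning_rate(optimizer, i_iter, cfg):
    _adjust_learning_rate(optimizer, i_iter, cfg, cfg.TRAIN.LEARNING_RATE)


def adjust_learning_rate_discriminator(optimizer, i_iter, cfg):
    _adjust_learning_rate(optimizer, i_iter, cfg, cfg.TRAIN.LEARNING_RATE_D)
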
Example #9
def train_advent(model, trainloader, targetloader, cfg, args):
    ''' UDA training with advent
    '''
    # Create the model and start the training.
    # pdb.set_trace()
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))

    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)
        # -------------------------------------------------------- #
        # codes to initialize wandb for storing logs on its cloud
        wandb.init(project='FDA_integration_to_INTRA_DA')
        wandb.config.update(args)

        for key, val in cfg.items():
            wandb.config.update({key: val})

        wandb.watch(model)
        # -------------------------------------------------------- #

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    # restore_from = cfg.TRAIN.RESTORE_FROM_aux
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_aux, restore_from, device)

    # seg maps, i.e. output-level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # restore_from = cfg.TRAIN.RESTORE_FROM_main
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_main, restore_from, device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(),
                                 lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):

        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet; don't accumulate grads in the discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False

        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch

        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch

        # ----------------------------------------------------------------#
        B, C, H, W = images_source.shape

        mean_images_source = SRC_IMG_MEAN.repeat(B, 1, H, W)
        mean_images = SRC_IMG_MEAN.repeat(B, 1, H, W)

        if args.FDA_mode == 'on':
            # normalize the source and target image
            images_source -= mean_images_source
            images -= mean_images

        elif args.FDA_mode == 'off':
            # keep source and target images as they are:
            # normalization was already done in the dataset class (GTA5, Cityscapes) when args.FDA_mode == 'off'
            pass

        else:
            raise KeyError()
        # ----------------------------------------------------------------#

        # debug:
        # labels=labels.numpy()
        # from matplotlib import pyplot as plt
        # import numpy as np
        # plt.figure(1), plt.imshow(labels[0]), plt.ion(), plt.colorbar(), plt.show()

        # train on source
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        # pdb.set_trace()
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main +
                cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main +
                cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            # ----------------------------------------------------------------#

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                current_losses = {
                    'loss_seg_src_aux': loss_seg_src_aux,
                    'loss_seg_src_main': loss_seg_src_main,
                    'loss_adv_trg_aux': loss_adv_trg_aux,
                    'loss_adv_trg_main': loss_adv_trg_main,
                    'loss_d_aux': loss_d_aux,
                    'loss_d_main': loss_d_main
                }
                print_losses(current_losses, i_iter)

                log_losses_tensorboard(writer, current_losses, i_iter)
                draw_in_tensorboard(writer, images + mean_images, i_iter,
                                    pred_trg_main, num_classes, 'T')
                draw_in_tensorboard(writer, images_source + mean_images_source,
                                    i_iter, pred_src_main, num_classes, 'S')

                wandb.log({'loss': current_losses}, step=(i_iter + 1))
                # e.g. every 2500 iterations when TENSORBOARD_VIZRATE is 100
                if i_iter % (cfg.TRAIN.TENSORBOARD_VIZRATE * 25) == cfg.TRAIN.TENSORBOARD_VIZRATE * 25 - 1:
                    wandb.log(
                        {'source': wandb.Image(torch.flip(images_source+mean_images_source, [1]).cpu().data[0].numpy().transpose((1, 2, 0))), \
                         'target': wandb.Image(torch.flip(images+mean_images, [1]).cpu().data[0].numpy().transpose((1, 2, 0))),
                         'pseudo label': wandb.Image(np.asarray(colorize_mask(np.asarray(labels.cpu().data.numpy().transpose(1,2,0).reshape((512,1024)), dtype=np.uint8)).convert('RGB')) )},
                        step=(i_iter + 1))
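
Finally, loss_calc in these training loops is assumed to be a pixel-wise 2-D cross-entropy over the upsampled logits that ignores the 255 void label; a minimal sketch of that behaviour (the repo's own implementation masks void pixels explicitly):

import torch.nn.functional as F


def loss_calc(pred, label, device):
    """Cross-entropy between (N, C, H, W) logits and (N, H, W) labels,
    treating 255 as the ignored void class."""
    label = label.long().to(device)
    return F.cross_entropy(pred, label, ignore_index=255)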