Example #1
def train_minent(model, trainloader, targetloader, cfg):
    ''' UDA training with minEnt
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]), mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear',
                                align_corners=True)

    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP)):

        # reset optimizers
        optimizer.zero_grad()

        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)

        # UDA Training
        # train on source
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # entropy minimization on the target batch (minEnt)
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        pred_trg_aux = interp_target(pred_trg_aux)
        pred_trg_main = interp_target(pred_trg_main)
        pred_prob_trg_aux = F.softmax(pred_trg_aux, dim=1)
        pred_prob_trg_main = F.softmax(pred_trg_main, dim=1)

        loss_target_entp_aux = entropy_loss(pred_prob_trg_aux)
        loss_target_entp_main = entropy_loss(pred_prob_trg_main)
        loss = (cfg.TRAIN.LAMBDA_ENT_AUX * loss_target_entp_aux
                + cfg.TRAIN.LAMBDA_ENT_MAIN * loss_target_entp_main)
        loss.backward()
        optimizer.step()

        current_losses = {'loss_seg_src_aux': loss_seg_src_aux,
                          'loss_seg_src_main': loss_seg_src_main,
                          'loss_ent_aux': loss_target_entp_aux,
                          'loss_ent_main': loss_target_entp_main}

        print_losses(current_losses, i_iter)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       osp.join(cfg.TRAIN.SNAPSHOT_DIR, f'model_{i_iter}.pth'))
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main, num_classes, 'T')
                draw_in_tensorboard(writer, images_source, i_iter, pred_src_main, num_classes, 'S')
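
Note: `entropy_loss` is used above but not defined in this snippet. Below is a minimal sketch consistent with how it is called here, computing the mean per-pixel Shannon entropy of the softmax maps; the `1e-30` epsilon and the normalization by `log2(num_classes)` are assumptions, not taken from the snippet.

import numpy as np
import torch


def entropy_loss(prob):
    """Mean per-pixel entropy of softmax probabilities with shape (N, C, H, W)."""
    n, c, h, w = prob.size()
    # -sum_c p * log2(p), averaged over all pixels and normalized by log2(C)
    return -torch.sum(prob * torch.log2(prob + 1e-30)) / (n * h * w * np.log2(c))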
Example #2
def train_advent(model, trainloader, targetloader, cfg):
    ''' UDA training with advent
    '''
    # Create the model and start the training.
    # pdb.set_trace()
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    # restore_from = cfg.TRAIN.RESTORE_FROM_aux
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_aux, restore_from, device)


    # output (segmentation-map) level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # restore_from = cfg.TRAIN.RESTORE_FROM_main
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_main, restore_from, device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]), mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP)):

        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet. Don't accumulate grads in discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source 
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        # debug:
        # labels=labels.numpy()
        # from matplotlib import pyplot as plt
        # import numpy as np
        # plt.figure(1), plt.imshow(labels[0]), plt.ion(), plt.colorbar(), plt.show()
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        # pdb.set_trace()
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
                + cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
                + cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        current_losses = {'loss_seg_src_aux': loss_seg_src_aux,
                          'loss_seg_src_main': loss_seg_src_main,
                          'loss_adv_trg_aux': loss_adv_trg_aux,
                          'loss_adv_trg_main': loss_adv_trg_main,
                          'loss_d_aux': loss_d_aux,
                          'loss_d_main': loss_d_main}
        print_losses(current_losses, i_iter)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(), snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(), snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(), snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main, num_classes, 'T')
                draw_in_tensorboard(writer, images_source, i_iter, pred_src_main, num_classes, 'S')
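
The AdvEnt-style loop above relies on two helpers that are not shown: `prob_2_entropy`, which turns softmax maps into weighted self-information (entropy) maps fed to the discriminator, and `bce_loss`, which scores the discriminator output against a constant domain label. The following is a minimal sketch consistent with that usage; the epsilon and the `log2(C)` normalization are assumptions.

import numpy as np
import torch
from torch import nn


def prob_2_entropy(prob):
    """Convert softmax probabilities (N, C, H, W) into weighted self-information maps."""
    n, c, h, w = prob.size()
    return -prob * torch.log2(prob + 1e-30) / np.log2(c)


def bce_loss(y_pred, y_label):
    """BCE with logits against a constant domain label (0 = source, 1 = target)."""
    y_truth = torch.full_like(y_pred, float(y_label))
    return nn.BCEWithLogitsLoss()(y_pred, y_truth)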
def train_preview(model, source_loader, target_loader, cfg, comet_exp):
    # UDA TRAINING
    ''' UDA training with advent
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)

    # output (segmentation-map) level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(),
                                 lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    times = deque([0], maxlen=100)
    model_times = deque([0], maxlen=100)

    source_loader_iter = enumerate(source_loader)
    target_loader_iter = enumerate(target_loader)

    cur_best_miou = -1
    cur_best_model = ''

    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):
        times.append(time())
        comet_exp.log_metric("i_iter", i_iter)

        comet_exp.log_metric("target_epoch", i_iter / len(target_loader))
        comet_exp.log_metric("source_epoch", i_iter / len(source_loader))
        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet. Don't accumulate grads in discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source
        try:
            _, batch_and_path = source_loader_iter.__next__()
        except StopIteration:
            source_loader_iter = enumerate(source_loader)
            _, batch_and_path = source_loader_iter.__next__()

        images_source, labels = batch_and_path['data']['x'], batch_and_path['data']['m']
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main +
                cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator
        try:
            _, batch = target_loader_iter.__next__()
        except StopIteration:
            target_loader_iter = enumerate(target_loader)
            _, batch = target_loader_iter.__next__()

        images = batch['data']['x']
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main +
                cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        model_times.append(time() - times[-1])
        mod_times = np.mean(model_times)
        comet_exp.log_metric("model_time", mod_times)

        current_losses = {
            'loss_seg_src_aux': loss_seg_src_aux,
            'loss_seg_src_main': loss_seg_src_main,
            'loss_adv_trg_aux': loss_adv_trg_aux,
            'loss_adv_trg_main': loss_adv_trg_main,
            'loss_d_aux': loss_d_aux,
            'loss_d_main': loss_d_main
        }
        print_losses(current_losses, i_iter)
        current_losses_numDict = tesnorDict2numDict(current_losses)
        comet_exp.log_metrics(current_losses_numDict)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break

        if (i_iter % cfg.TRAIN.SAVE_IMAGE_PRED == 0 and i_iter != 0) or i_iter == cfg.TRAIN.EARLY_STOP:
            print("Inferring test images in iteration {}...".format(i_iter))
            hist = np.zeros((cfg.NUM_CLASSES, cfg.NUM_CLASSES))
            image, label = batch['data']['x'][0], batch['data']['m'][0]
            image = image[None, :, :, :]

            interp = nn.Upsample(size=(label.shape[1], label.shape[2]),
                                 mode='bilinear',
                                 align_corners=True)
            with torch.no_grad():
                pred_main = model(image.cuda(device))[1]
                output = interp(pred_main).cpu().data[0].numpy()
                output = output.transpose(1, 2, 0)
                output = np.argmax(output, axis=2)
            label0 = label.numpy()[0]
            hist += fast_hist(label0.flatten(), output.flatten(),
                              cfg.NUM_CLASSES)
            output = torch.tensor(output, dtype=torch.float32)
            output = output[None, :, :]
            output_RGB = output.repeat(3, 1, 1)

            if i_iter % 100 == 0:
                print('{:d} / {:d}: {:0.2f}'.format(
                    i_iter % len(target_loader), len(target_loader),
                    100 * np.nanmean(per_class_iu(hist))))
            inters_over_union_classes = per_class_iu(hist)
            computed_miou = round(
                np.nanmean(inters_over_union_classes) * 100, 2)
            if cur_best_miou < computed_miou:
                cur_best_miou = computed_miou
                cur_best_model = f'model_{i_iter}.pth'
            print('\tCurrent mIoU:', computed_miou)
            print('\tCurrent best model:', cur_best_model)
            print('\tCurrent best mIoU:', cur_best_miou)
            mious = {
                'Current mIoU': computed_miou,
                'Current best model': cur_best_model,
                'Current best mIoU': cur_best_miou
            }
            comet_exp.log_metrics(mious)
            image = image[0]  # change size from [1,x,y,z] to [x,y,z]
            save_images = []

            save_images.append(image)
            # Overlay mask:

            save_mask = (image - (image * label.repeat(3, 1, 1)) +
                         label.repeat(3, 1, 1))

            save_fake_mask = (image - (image * output_RGB) + output_RGB)
            save_images.append(save_mask)
            save_images.append(save_fake_mask)
            save_images.append(label.repeat(3, 1, 1))
            save_images.append(output_RGB)

            write_images(save_images,
                         i_iter,
                         comet_exp=comet_exp,
                         store_im=cfg.TEST.store_images)
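
`fast_hist` and `per_class_iu`, used above to track mIoU on the sampled target batch, are the usual confusion-matrix helpers. A sketch of the standard implementations is given below under that assumption; they are not shown in the snippet itself.

import numpy as np


def fast_hist(a, b, n):
    """Confusion matrix between flattened label map `a` and prediction `b` over `n` classes."""
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k].astype(int), minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):
    """Per-class IoU = TP / (TP + FP + FN) from a confusion matrix."""
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))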
Example #4
def train_advent(model, trainloader, targetloader, cfg):
    ''' UDA training with advent
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    # d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux = get_fe_discriminator(num_classes=1024)
    # saved_state_dict_D1 = torch.load('C:\\Users\\Administrator\\OneDrive - University of Ottawa\\Python\\ADVENT-master\experiments\\snapshots\\GTA2Cityscapes_DeepLabv2_AdvEnt413\\model_125000_D_aux.pth')
    # d_aux.load_state_dict(saved_state_dict_D1)
    d_aux.train()
    d_aux.to(device)

    # output (segmentation-map) level
    d_main = get_fc_discriminator(num_classes=num_classes)
    # saved_state_dict_D2 = torch.load('C:\\Users\\Administrator\\OneDrive - University of Ottawa\\Python\\ADVENT-master\\experiments\\snapshots\\GTA2Cityscapes_DeepLabv2_AdvEnt413\\model_125000_D_main.pth')
    # d_main.load_state_dict(saved_state_dict_D2)
    d_main.train()
    d_main.to(device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(),
                                 lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)
    # interp_aux = nn.Upsample(size=(128, 256), mode='bilinear', align_corners=True)   # H/4
    # interp_aux_source = nn.Upsample(size=(180, 320), mode='bilinear', align_corners=True)   # H/4

    weighted_bce_loss = WeightedBCEWithLogitsLoss()
    criterion_seg = nn.CrossEntropyLoss(ignore_index=255)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    Epsilon = 0.1
    Lambda_local = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):

        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        damping = (1 - i_iter / 100000)

        ### UDA Training
        # only train segnet. Don't accumulate grads in discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        pred_src_aux, pred_src_main = model(
            images_source.cuda(device)
        )  # H/8 multi-level outputs coming from both conv4 and conv5
        # pred_src_aux = interp_aux_source(pred_src_aux)  # H/4=1280/4
        loss_seg_src_aux = 0
        # if cfg.TRAIN.MULTI_LEVEL:
        #     pred_src_aux = interp(pred_src_aux)
        #     loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        #     # pred_src_aux = F.softmax(pred_src_aux1)
        # else:
        #     loss_seg_src_aux = 0

        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main +
                cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        pred_trg_aux, pred_trg_main = model(
            images.cuda(device))  # H/8=120, H/8=129
        # pred_trg_aux = interp_aux(pred_trg_aux)  # H/4=256
        pred_trg_main_0 = interp_target(pred_trg_main)
        pred_trg_main = F.softmax(pred_trg_main_0, dim=1)

        def toweight(x):
            # map the discriminator output to a per-pixel weight in (0, 1.5):
            # standardize, squash through a sigmoid, then scale by 1.5
            x = x.cpu().data[0][0]
            x = preprocessing.scale(x)
            x = 1 / (1 + np.exp(-x))
            x = x * 1.5
            x = torch.tensor(x, dtype=torch.float32, device=device)
            return x

        if cfg.TRAIN.MULTI_LEVEL:
            # pred_trg_aux = F.softmax(interp_target(pred_trg_aux))
            # d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux))) # -p*log(p)
            d_out_aux = interp_target(d_aux(pred_trg_aux))  # H/8->H/8->H
            loss_adv_trg_aux = 0
            ones = torch.ones_like(d_out_aux)
            zero = torch.zeros_like(d_out_aux)

            # if (i_iter > 5000):
            #     pred_trg_aux_conf = 1.0 - torch.max(pred_trg_aux, 1)[0]
            #     weight_map_aux = torch.unsqueeze(pred_trg_aux_conf, dim=0)
            #     loss_adv_trg_aux = weighted_bce_loss(d_out_aux, Variable(torch.FloatTensor(d_out_aux.data.size()).fill_(source_label).to(device)),
            #                                          weight_map_aux, Epsilon , Lambda_local)
            # else:
            #     loss_adv_trg_aux = bce_loss(d_out_aux, source_label)

        else:
            loss_adv_trg_aux = 0

        # pred_trg_main = F.softmax(interp_target(pred_trg_main))  # H/8->H
        d_out_main = interp_target(d_main(pred_trg_main))  # H->H/8->H
        # loss_adv_trg_main = bce_loss(d_out_main, source_label)

        if (i_iter > 5000):

            maxpred, label = torch.max(pred_trg_main.detach(), dim=1)
            mask = (maxpred > 0.90)
            label = torch.where(
                mask, label,
                torch.ones(1).to(device, dtype=torch.long) * 255)
            loss_seg_trg_main = criterion_seg(pred_trg_main_0, label)
            # loss_seg_trg_main_.backward()

            pred_trg_main_conf = 1.0 - torch.max(pred_trg_main, 1)[0]
            fweight = toweight(d_out_aux)
            # pred_trg_main_conf = 1 - torch.max(pred_trg_main.detach(), 1)[0]
            # fweight = toweight(d_out_aux.detach())
            weight_map_main = pred_trg_main_conf * fweight
            weight_map_main = torch.where(weight_map_main > 1, ones,
                                          weight_map_main)
            weight_map_main = torch.where(weight_map_main < 0.05, zero,
                                          weight_map_main)

            # weight_map_main = torch.unsqueeze(weight_map_main, dim=0)
            loss_adv_trg_main = weighted_bce_loss(
                d_out_main,
                Variable(
                    torch.FloatTensor(d_out_main.data.size()).fill_(
                        source_label).to(device)), weight_map_main, Epsilon,
                Lambda_local)
        else:
            loss_adv_trg_main = bce_loss(d_out_main, source_label)
            loss_seg_trg_main = 0

        loss = cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main * damping
        loss.backward()

        ### Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            # d_out_aux = interp(d_aux(F.softmax(pred_src_aux))) # -plog(p)
            d_out_aux = interp(d_aux(pred_src_aux))  # H/8->H/8->H
            # d_out_aux = d_aux(prob_2_entropy(pred_src_aux))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = interp(d_main(F.softmax(pred_src_main, dim=1)))  # H->H/8->H
        # d_out_main = d_main(prob_2_entropy(pred_src_main))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            # d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            d_out_aux = interp_target(d_aux(pred_trg_aux))  # H/8->H/8->H
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
            # if (i_iter > 5000):
            #     weight_map_aux = weight_map_aux.detach()
            #     loss_d_aux = weighted_bce_loss(d_out_aux, Variable(torch.FloatTensor(d_out_aux.data.size()).fill_(target_label).to(device)),
            #                                          weight_map_aux, Epsilon, Lambda_local)
            # else:
            #     loss_d_aux = bce_loss(d_out_aux, target_label)

        else:
            loss_d_aux = 0

        pred_trg_main = pred_trg_main.detach()
        d_out_main = interp_target(d_main(pred_trg_main))
        # loss_d_main = bce_loss(d_out_main, target_label)

        if (i_iter > 5000):
            pred_trg_main_conf = pred_trg_main_conf.detach()
            fweight = toweight(d_out_aux)
            # fweight = toweight(d_out_aux.detach())
            weight_map_main = pred_trg_main_conf * fweight
            weight_map_main = torch.where(weight_map_main > 1, ones,
                                          weight_map_main)
            # weight_map_main = torch.unsqueeze(weight_map_main, dim=0)
            loss_d_main = weighted_bce_loss(
                d_out_main,
                Variable(
                    torch.FloatTensor(d_out_main.data.size()).fill_(
                        target_label).to(device)), weight_map_main, Epsilon,
                Lambda_local)
        else:
            loss_d_main = bce_loss(d_out_main, target_label)

        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        current_losses = {
            'loss_seg_trg_main': loss_seg_trg_main,
            'loss_seg_src_main': loss_seg_src_main,
            'loss_adv_trg_aux': loss_adv_trg_aux,
            'loss_adv_trg_main': loss_adv_trg_main,
            'loss_d_aux': loss_d_aux,
            'loss_d_main': loss_d_main
        }
        print_losses(current_losses, i_iter)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main,
                                    num_classes, 'T')
                draw_in_tensorboard(writer, images_source, i_iter,
                                    pred_src_main, num_classes, 'S')
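
Every loop in these examples calls `adjust_learning_rate` / `adjust_learning_rate_discriminator` once per iteration. Below is a sketch of the polynomial ("poly") decay schedule these names usually refer to; the `cfg.TRAIN.MAX_ITERS` and `cfg.TRAIN.POWER` fields and the 10x learning rate for a second parameter group are assumptions, not shown in the snippets.

def lr_poly(base_lr, i_iter, max_iters, power):
    """Polynomial decay: base_lr * (1 - i_iter / max_iters) ** power."""
    return base_lr * ((1 - float(i_iter) / max_iters) ** power)


def _set_lr(optimizer, lr):
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        # a second group (e.g. the classifier head) commonly runs at 10x the base LR
        optimizer.param_groups[1]['lr'] = lr * 10


def adjust_learning_rate(optimizer, i_iter, cfg):
    _set_lr(optimizer, lr_poly(cfg.TRAIN.LEARNING_RATE, i_iter,
                               cfg.TRAIN.MAX_ITERS, cfg.TRAIN.POWER))


def adjust_learning_rate_discriminator(optimizer, i_iter, cfg):
    _set_lr(optimizer, lr_poly(cfg.TRAIN.LEARNING_RATE_D, i_iter,
                               cfg.TRAIN.MAX_ITERS, cfg.TRAIN.POWER))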
Example #5
def train_dada(model, trainloader, targetloader, cfg):
    """ UDA training with dada
    """
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # output (segmentation-map) level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(
        model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
        lr=cfg.TRAIN.LEARNING_RATE,
        momentum=cfg.TRAIN.MOMENTUM,
        weight_decay=cfg.TRAIN.WEIGHT_DECAY,
    )

    # discriminators' optimizers
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(
        size=(input_size_source[1], input_size_source[0]),
        mode="bilinear",
        align_corners=True,
    )
    interp_target = nn.Upsample(
        size=(input_size_target[1], input_size_target[0]),
        mode="bilinear",
        align_corners=True,
    )

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):
        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_main.zero_grad()
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet. Don't accumulate grads in discriminators
        for param in d_main.parameters():
            param.requires_grad = False
        # train on source
        _, batch = trainloader_iter.__next__()
        images_source, labels, depth, _, _ = batch
        _, pred_src_main, pred_depth_src_main = model(
            images_source.cuda(device))
        pred_src_main = interp(pred_src_main)
        pred_depth_src_main = interp(pred_depth_src_main)
        loss_depth_src_main = loss_calc_depth(pred_depth_src_main, depth,
                                              device)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main +
                cfg.TRAIN.LAMBDA_DEPTH_MAIN * loss_depth_src_main)
        loss.backward()

        # adversarial training to fool the discriminator
        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch
        _, pred_trg_main, pred_depth_trg_main = model(images.cuda(device))
        pred_trg_main = interp_target(pred_trg_main)
        pred_depth_trg_main = interp_target(pred_depth_trg_main)
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_trg_main, dim=1)) * pred_depth_trg_main)
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        pred_src_main = pred_src_main.detach()
        pred_depth_src_main = pred_depth_src_main.detach()
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_src_main, dim=1)) * pred_depth_src_main)
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main.backward()

        # train with target
        pred_trg_main = pred_trg_main.detach()
        pred_depth_trg_main = pred_depth_trg_main.detach()
        d_out_main = d_main(
            prob_2_entropy(F.softmax(pred_trg_main, dim=1)) * pred_depth_trg_main)
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main.backward()

        optimizer.step()
        optimizer_d_main.step()

        current_losses = {
            "loss_seg_src_main": loss_seg_src_main,
            "loss_depth_src_main": loss_depth_src_main,
            "loss_adv_trg_main": loss_adv_trg_main,
            "loss_d_main": loss_d_main,
        }
        print_losses(current_losses, i_iter)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print("taking snapshot ...")
            print("exp =", cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f"model_{i_iter}.pth")
            torch.save(d_main.state_dict(),
                       snapshot_dir / f"model_{i_iter}_D_main.pth")
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                draw_in_tensorboard(writer, images, i_iter, pred_trg_main,
                                    num_classes, "T")
                draw_in_tensorboard(writer, images_source, i_iter,
                                    pred_src_main, num_classes, "S")
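
`get_fc_discriminator`, used by most of these examples, builds a small fully convolutional domain classifier over the C-channel entropy maps and returns a one-channel logit map. The sketch below shows the usual DCGAN-style stack; the filter count `ndf=64` and the five-layer depth are assumptions.

from torch import nn


def get_fc_discriminator(num_classes, ndf=64):
    """Fully convolutional discriminator: (N, num_classes, H, W) -> (N, 1, H/32, W/32) logits."""
    return nn.Sequential(
        nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1),
    )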
Example #6
def train_advent(model, trainloader, targetloader, cfg, args):
    ''' UDA training with advent
    '''
    # Create the model and start the training.
    # pdb.set_trace()
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))

    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)
        # -------------------------------------------------------- #
        # codes to initialize wandb for storing logs on its cloud
        wandb.init(project='FDA_integration_to_INTRA_DA')
        wandb.config.update(args)

        for key, val in cfg.items():
            wandb.config.update({key: val})

        wandb.watch(model)
        # -------------------------------------------------------- #

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    # restore_from = cfg.TRAIN.RESTORE_FROM_aux
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_aux, restore_from, device)

    # output (segmentation-map) level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # restore_from = cfg.TRAIN.RESTORE_FROM_main
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_main, restore_from, device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(),
                                 lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):

        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet. Don't accumulate grads in discriminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False

        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch

        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch

        # ----------------------------------------------------------------#
        B, C, H, W = images_source.shape

        mean_images_source = SRC_IMG_MEAN.repeat(B, 1, H, W)
        mean_images = SRC_IMG_MEAN.repeat(B, 1, H, W)

        if args.FDA_mode == 'on':
            # normalize the source and target image
            images_source -= mean_images_source
            images -= mean_images

        elif args.FDA_mode == 'off':
            # keep the source and target images as they are: mean normalization has
            # already been done in the dataset classes (GTA5, Cityscapes) when
            # args.FDA_mode == 'off'
            pass

        else:
            raise KeyError(f'Unknown FDA_mode: {args.FDA_mode}')
        # ----------------------------------------------------------------#

        # debug:
        # labels=labels.numpy()
        # from matplotlib import pyplot as plt
        # import numpy as np
        # plt.figure(1), plt.imshow(labels[0]), plt.ion(), plt.colorbar(), plt.show()

        # train on source
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        # pdb.set_trace()
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main +
                cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training to fool the discriminator
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main +
                cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux, dim=1)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main, dim=1)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            # ----------------------------------------------------------------#

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                current_losses = {
                    'loss_seg_src_aux': loss_seg_src_aux,
                    'loss_seg_src_main': loss_seg_src_main,
                    'loss_adv_trg_aux': loss_adv_trg_aux,
                    'loss_adv_trg_main': loss_adv_trg_main,
                    'loss_d_aux': loss_d_aux,
                    'loss_d_main': loss_d_main
                }
                print_losses(current_losses, i_iter)

                log_losses_tensorboard(writer, current_losses, i_iter)
                draw_in_tensorboard(writer, images + mean_images, i_iter,
                                    pred_trg_main, num_classes, 'T')
                draw_in_tensorboard(writer, images_source + mean_images_source,
                                    i_iter, pred_src_main, num_classes, 'S')

                wandb.log({'loss': current_losses}, step=(i_iter + 1))
                if i_iter % (cfg.TRAIN.TENSORBOARD_VIZRATE * 25) == \
                        cfg.TRAIN.TENSORBOARD_VIZRATE * 25 - 1:  # roughly every 2500 iterations
                    wandb.log(
                        {'source': wandb.Image(torch.flip(images_source+mean_images_source, [1]).cpu().data[0].numpy().transpose((1, 2, 0))), \
                         'target': wandb.Image(torch.flip(images+mean_images, [1]).cpu().data[0].numpy().transpose((1, 2, 0))),
                         'pseudo label': wandb.Image(np.asarray(colorize_mask(np.asarray(labels.cpu().data.numpy().transpose(1, 2, 0).reshape((512, 1024)), dtype=np.uint8)).convert('RGB')))},
                        step=(i_iter + 1))
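
`loss_calc`, the supervised segmentation loss used throughout these snippets, is typically a pixel-wise cross-entropy that ignores the void label 255. A minimal sketch under that assumption:

from torch import nn


def loss_calc(pred, label, device):
    """Pixel-wise cross-entropy between logits (N, C, H, W) and labels (N, H, W); 255 marks void pixels."""
    label = label.long().to(device)
    return nn.CrossEntropyLoss(ignore_index=255)(pred, label)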
Example #7
def train_self_domain_swarp(model, trainloader, targetloader, cfg):
    ''' UDA self-training with domain swarping
    '''
    # Create the model and start the training.
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)

    # SEGMENTATION NETWORK
    model.train()
    model.to(device)

    # Model clone
    model_runner = copy.deepcopy(model)
    model_runner.eval()
    model_runner.to(device)

    # conv3x3_tgt = get_conv_abstract(cfg)
    # conv3x3_tgt.train()
    # conv3x3_tgt.to(device)

    # d_main = get_fc_discriminator(num_classes=num_classes)
    # d_main.train()
    # d_main.to(device)

    tgt_dict_tot = {}

    cudnn.benchmark = True
    cudnn.enabled = True

    # OPTIMIZERS
    # params = list(model.parameters()) + list(conv3x3_tgt.parameters())
    optimizer = optim.SGD(model.parameters(),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]), mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]), mode='bilinear',
                                align_corners=True)

    cls_thresh = torch.ones(num_classes).type(torch.float32)

    # optimizer_d_main = optim.Adam(d_main.parameters(), lr=cfg.TRAIN.LEARNING_RATE_D,
    #                               betas=(0.9, 0.99))

    # for round in range(3):

    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)

    source_label = 0
    target_label = 1

    tot_iter = len(targetloader)

    for i_iter in tqdm(range(tot_iter)):

        # reset optimizers
        optimizer.zero_grad()
        # optimizer_d_main.zero_grad()

        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        # adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # train on source
        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch
        pred_src_main, _ = model(images_source.cuda(device))
 
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        loss = cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main
        loss.backward()

        # self-training on the target batch with pseudo labels
        _, batch = targetloader_iter.__next__()
        images, images_rev, _, _, name, name_next = batch

        pred_trg_main, feat_trg_main = model(images.cuda(device))

        pred_trg_main = interp_target(pred_trg_main)
        
        with torch.no_grad():
            pred_trg_main_run, feat_trg_main_run = model_runner(images.cuda(device))
            pred_trg_main_run = interp_target(pred_trg_main_run)            

        ##### Label generator for target #####
        label_trg, cls_thresh = label_generator(pred_trg_main_run, cls_thresh, cfg, i_iter, tot_iter)


        ##### CE loss for trg
        # MRKLD + Ign Region
        loss_seg_trg_main = reg_loss_calc_ign(pred_trg_main, label_trg, device)
        loss_tgt_seg = cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_trg_main

        ##### Domain swarping ####
        feat_tgt_swarped, tgt_dict_tot, tgt_label = DomainSwarping(feat_trg_main, label_trg, tgt_dict_tot, device)

        ignore_mask = tgt_label == 255

        feat_tgt_swarped = ~ignore_mask*feat_tgt_swarped + ignore_mask*feat_trg_main
        pred_tgt_swarped = model.classifier_(feat_tgt_swarped)
        pred_tgt_swarped = interp_target(pred_tgt_swarped)

        loss_seg_trg_swarped = reg_loss_calc_ign(pred_tgt_swarped, label_trg, device)
        loss_tgt_seg_swarped = cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_trg_swarped

        loss_tgt = loss_tgt_seg + loss_tgt_seg_swarped

        
        loss_tgt.backward()


        optimizer.step()

        current_losses = {'loss_seg_trg_main': loss_seg_trg_main,
                          'loss_seg_src_main': loss_seg_src_main,
                          'loss_seg_trg_swarped': loss_seg_trg_swarped
                          }

        print_losses(current_losses, i_iter)

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(), snapshot_dir / f'model_{i_iter}.pth')
            torch.save(model_runner.state_dict(), snapshot_dir / f'model_{i_iter}_run.pth')

            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            log_losses_tensorboard(writer, current_losses, i_iter)

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == 0:
                # draw_in_tensorboard_trg(writer, images, images_rev, label_trg, i_iter, pred_trg_main, pred_trg_main_rev, num_classes, 'T')
                draw_in_tensorboard(writer, images, label_trg, i_iter, pred_trg_main, pred_tgt_swarped, num_classes, 'T')
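
The self-training loop above depends on `label_generator` to produce pseudo labels for the target batch. Its exact behavior, including how `cls_thresh` is updated across iterations, is not shown in this snippet; the following is a simplified, hypothetical sketch that only does confidence thresholding and marks low-confidence pixels with the ignore label 255.

import torch


def label_generator(pred_trg_main, cls_thresh, cfg, i_iter, tot_iter):
    """Hypothetical sketch: argmax pseudo labels, with pixels below their
    per-class confidence threshold set to 255 (ignored by the loss)."""
    prob = torch.softmax(pred_trg_main.detach(), dim=1)
    conf, label = prob.max(dim=1)
    thresh = cls_thresh.to(conf.device)[label]   # per-pixel threshold, indexed by predicted class
    ignore = torch.full_like(label, 255)
    label_trg = torch.where(conf >= thresh, label, ignore)
    # the real helper also adapts cls_thresh over training; it is returned unchanged here
    return label_trg, cls_thresh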