Exemplo n.º 1
0
def draw_in_tensorboard(writer, images, label_trg, i_iter, pred_main, pred_main_swarp, num_classes, type_):
    """Log input images, side-by-side main/swarp predictions, and labels.

    Args:
        writer: tensorboard SummaryWriter.
        images: batch of input images (N, C, H, W); first 3 are visualized.
        label_trg: target label map tensor (assumed squeezable to H x W of
            class ids — TODO confirm against caller).
        i_iter: global step used as the tensorboard x-axis.
        pred_main: main-branch logits (N, num_classes, H, W).
        pred_main_swarp: swarp-branch logits, same shape as pred_main.
        num_classes: unused here; kept for signature compatibility.
        type_: tag suffix, e.g. 'S' or 'T', appended to every image name.
    """
    grid_image = make_grid(images[:3].clone().cpu().data, 3, normalize=True)
    writer.add_image(f'Image - {type_}', grid_image, i_iter)

    # Concatenate along the last (width) axis so both predictions render
    # side by side in a single grid image.
    pred_main_cat = torch.cat((pred_main, pred_main_swarp), dim=-1)

    # FIX: pass dim=1 explicitly — implicit-dim F.softmax is deprecated and
    # only resolved to the class channel by a legacy heuristic for 4D inputs.
    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        np.argmax(F.softmax(pred_main_cat, dim=1).cpu().data[0].numpy().transpose(1, 2, 0),
                  axis=2), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)), 3,
                           normalize=False, range=(0, 255))
    writer.add_image(f'Prediction_main_swarp - {type_}', grid_image, i_iter)

    # Colorized ground-truth (or pseudo) labels for visual comparison.
    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        label_trg.cpu().squeeze(), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)), 3,
                           normalize=False, range=(0, 255))
    writer.add_image(f'Labels_IAST - {type_}', grid_image, i_iter)
Exemplo n.º 2
0
def draw_in_tensorboard(writer, images, i_iter, pred_main, num_classes, type_):
    """Log input images, the colorized prediction, and a per-pixel entropy map.

    Args:
        writer: tensorboard SummaryWriter.
        images: batch of input images (N, C, H, W); first 3 are visualized.
        i_iter: global step used as the tensorboard x-axis.
        pred_main: segmentation logits (N, num_classes, H, W).
        num_classes: number of classes; log2(num_classes) is the max entropy
            used as the normalization range of the entropy map.
        type_: tag suffix, e.g. 'S' or 'T', appended to every image name.
    """
    grid_image = make_grid(images[:3].clone().cpu().data, 3, normalize=True)
    writer.add_image(f'Image - {type_}', grid_image, i_iter)

    # FIX: softmax over the class channel (explicit dim=1 — implicit dim is
    # deprecated), computed ONCE and reused for both the prediction map and
    # the entropy map; the original ran the full softmax twice.
    output_sm = F.softmax(pred_main, dim=1).cpu().data[0].numpy().transpose(1, 2, 0)

    # Colorized argmax prediction.
    mask_rgb = np.array(colorize_mask(np.asarray(
        np.argmax(output_sm, axis=2), dtype=np.uint8)).convert('RGB'))
    grid_image = make_grid(torch.from_numpy(mask_rgb.transpose(2, 0, 1)),
                           3,
                           normalize=False,
                           range=(0, 255))
    writer.add_image(f'Prediction - {type_}', grid_image, i_iter)

    # FIX: guard log2(0) — for entropy, 0 * log2(0) is defined as 0, but a
    # probability that underflows to exactly 0 would make numpy produce
    # -inf then NaN and break make_grid's normalization.
    log_sm = np.log2(output_sm, out=np.zeros_like(output_sm),
                     where=output_sm > 0)
    output_ent = np.sum(-np.multiply(output_sm, log_sm),
                        axis=2,
                        keepdims=False)
    grid_image = make_grid(torch.from_numpy(output_ent),
                           3,
                           normalize=True,
                           range=(0, np.log2(num_classes)))
    writer.add_image(f'Entropy - {type_}', grid_image, i_iter)
Exemplo n.º 3
0
def draw_in_tensorboard(writer, images, images_aug, i_iter, pred_main, pred_main_aug, type_):
    """Log original/augmented input images and their colorized predictions.

    Args:
        writer: tensorboard SummaryWriter.
        images: batch of input images (N, C, H, W); first 3 are visualized.
        images_aug: augmented counterpart batch, same layout.
        i_iter: global step used as the tensorboard x-axis.
        pred_main: segmentation logits for `images` (N, num_classes, H, W).
        pred_main_aug: segmentation logits for `images_aug`, same shape.
        type_: tag suffix, e.g. 'S' or 'T', appended to every image name.
    """
    grid_image = make_grid(images[:3].clone().cpu().data, 3, normalize=True)
    writer.add_image(f'Image - {type_}', grid_image, i_iter)

    grid_image = make_grid(images_aug[:3].clone().cpu().data, 3, normalize=True)
    writer.add_image(f'Image_aug - {type_}', grid_image, i_iter)

    # FIX: pass dim=1 explicitly in both softmax calls — implicit-dim
    # F.softmax is deprecated and only lands on the class channel via a
    # legacy heuristic for 4D inputs.
    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        np.argmax(F.softmax(pred_main, dim=1).cpu().data[0].numpy().transpose(1, 2, 0),
                  axis=2), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)), 3,
                           normalize=False, range=(0, 255))
    writer.add_image(f'Prediction - {type_}', grid_image, i_iter)

    grid_image = make_grid(torch.from_numpy(np.array(colorize_mask(np.asarray(
        np.argmax(F.softmax(pred_main_aug, dim=1).cpu().data[0].numpy().transpose(1, 2, 0),
                  axis=2), dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)), 3,
                           normalize=False, range=(0, 255))
    writer.add_image(f'Prediction_aug - {type_}', grid_image, i_iter)
def eval_best(cfg, models, device, test_loader, interp, fixed_test_size,
              verbose):
    """Evaluate every saved snapshot and track the best-mIoU checkpoint.

    Iterates over snapshots `model_{i}.pth` from SNAPSHOT_STEP to
    SNAPSHOT_MAXITER (optionally waiting for them to appear), computes the
    per-class IoU on `test_loader`, caches results in `all_res.pkl`, and
    logs everything (IoU per class, mIoU, a sample prediction, best mIoU)
    to Weights & Biases.

    Args:
        cfg: experiment config (EasyDict-like; iterable via .items()).
        models: list with exactly one segmentation model.
        device: CUDA device id the model runs on.
        interp: upsampler applied to predictions; replaced per-sample when
            `fixed_test_size` is False.
        fixed_test_size: if False, rebuild `interp` to each label's size.
        verbose: print running mIoU every 100 batches and final class stats.
    """
    # -------------------------------------------------------- #
    # codes to initialize wandb for storing logs on its cloud
    wandb.init(project='FDA_integration_to_INTRA_DA')

    for key, val in cfg.items():
        wandb.config.update({key: val})

    # -------------------------------------------------------- #
    assert len(models) == 1, 'Not yet supported multi models in this mode'
    assert osp.exists(cfg.TEST.SNAPSHOT_DIR[0]), 'SNAPSHOT_DIR is not found'
    start_iter = cfg.TEST.SNAPSHOT_STEP
    step = cfg.TEST.SNAPSHOT_STEP
    max_iter = cfg.TEST.SNAPSHOT_MAXITER
    cache_path = osp.join(cfg.TEST.SNAPSHOT_DIR[0], 'all_res.pkl')
    # BUG FIX: the original did `cache_path = pickle_load(cache_path)`,
    # clobbering the path with the unpickled dict and leaving `all_res`
    # undefined whenever the cache file existed — both the cache lookup and
    # the later `pickle_dump(all_res, cache_path)` would then crash.
    if osp.exists(cache_path):
        all_res = pickle_load(cache_path)
    else:
        all_res = {}
    cur_best_miou = -1
    cur_best_model = ''
    for i_iter in range(start_iter, max_iter + 1, step):
        restore_from = osp.join(cfg.TEST.SNAPSHOT_DIR[0],
                                f'model_{i_iter}.pth')
        if not osp.exists(restore_from):
            # Optionally block until the training process writes the snapshot.
            if cfg.TEST.WAIT_MODEL:
                print('Waiting for model..!')
                while not osp.exists(restore_from):
                    time.sleep(5)
        print("Evaluating model", restore_from)
        if i_iter not in all_res:
            load_checkpoint_for_evaluation(models[0], restore_from, device)
            # Confusion matrix accumulated over the whole test set.
            hist = np.zeros((cfg.NUM_CLASSES, cfg.NUM_CLASSES))
            test_iter = iter(test_loader)
            for index in tqdm(range(len(test_loader))):
                image, label, _, name = next(test_iter)
                if not fixed_test_size:
                    interp = nn.Upsample(size=(label.shape[1], label.shape[2]),
                                         mode='bilinear',
                                         align_corners=True)
                with torch.no_grad():
                    # models[0](...)[1] is the main-branch logits.
                    pred_main = models[0](image.cuda(device))[1]
                    output = interp(pred_main).cpu().data[0].numpy()
                    output = output.transpose(1, 2, 0)
                    output = np.argmax(output, axis=2)
                label = label.numpy()[0]
                hist += fast_hist(label.flatten(), output.flatten(),
                                  cfg.NUM_CLASSES)
                if verbose and index > 0 and index % 100 == 0:
                    print('{:d} / {:d}: {:0.2f}'.format(
                        index, len(test_loader),
                        100 * np.nanmean(per_class_iu(hist))))
            inters_over_union_classes = per_class_iu(hist)
            all_res[i_iter] = inters_over_union_classes
            pickle_dump(all_res, cache_path)

            # -------------------------------------------------------- #
            # save logs at weight and biases
            IoU_classes = {}
            for idx in range(cfg.NUM_CLASSES):
                IoU_classes[test_loader.dataset.class_names[idx]] = round(
                    inters_over_union_classes[idx] * 100, 2)
            wandb.log(IoU_classes, step=(i_iter))
            wandb.log(
                {
                    'mIoU19': round(
                        np.nanmean(inters_over_union_classes) * 100, 2)
                },
                step=(i_iter))
            # Logs the prediction of the LAST test sample as a sanity image.
            wandb.log(
                {
                    'val_prediction':
                    wandb.Image(
                        colorize_mask(np.asarray(
                            output, dtype=np.uint8)).convert('RGB'))
                },
                step=(i_iter))
            # -------------------------------------------------------- #

        else:
            # Cached result from a previous run — skip re-evaluation.
            inters_over_union_classes = all_res[i_iter]
        computed_miou = round(np.nanmean(inters_over_union_classes) * 100, 2)
        if cur_best_miou < computed_miou:
            cur_best_miou = computed_miou
            cur_best_model = restore_from
        print('\tCurrent mIoU:', computed_miou)
        print('\tCurrent best model:', cur_best_model)
        print('\tCurrent best mIoU:', cur_best_miou)
        wandb.log({'best mIoU': cur_best_miou}, step=(i_iter))
        if verbose:
            display_stats(cfg, test_loader.dataset.class_names,
                          inters_over_union_classes)
Exemplo n.º 5
0
def eval_single(cfg, models,
                device, test_loader, interp,
                fixed_test_size, verbose):
    """Evaluate fixed checkpoint(s) on the test set and save visualizations.

    Loads one checkpoint per model from cfg.TEST.RESTORE_FROM, runs a
    (weighted-ensemble) forward pass over `test_loader`, accumulates a
    confusion matrix for mIoU, and writes per-image artifacts (colorized
    segmentation, boundary maps, colorized label) under
    `result_root/<run>/eval_image`.

    Args:
        cfg: experiment config with TEST.RESTORE_FROM, TEST.MODEL_WEIGHT,
            and NUM_CLASSES.
        models: list of segmentation models, matched 1:1 with checkpoints.
        device: CUDA device id.
        interp: upsampler applied to predictions; replaced per-sample when
            `fixed_test_size` is False.
        fixed_test_size: if False, rebuild `interp` to each label's size.
        verbose: print per-class stats at the end.
    """
    assert len(cfg.TEST.RESTORE_FROM) == len(models), 'Number of models are not matched'

    folder_path = cfg.TEST.RESTORE_FROM[0].split('/')[-2]
    folder_path = osp.join(result_root, folder_path, "eval_image")
    if not osp.exists(folder_path):
        os.makedirs(folder_path)

    for checkpoint, model in zip(cfg.TEST.RESTORE_FROM, models):
        load_checkpoint_for_evaluation(model, checkpoint, device)
    # Confusion matrix accumulated over the whole test set.
    hist = np.zeros((cfg.NUM_CLASSES, cfg.NUM_CLASSES))
    for index, batch in tqdm(enumerate(test_loader)):
        image, label, _, name = batch
        if not fixed_test_size:
            interp = nn.Upsample(size=(label.shape[1], label.shape[2]), mode='bilinear', align_corners=True)
        with torch.no_grad():
            output = None
            for model, model_weight in zip(models, cfg.TEST.MODEL_WEIGHT):
                _, pred_main, pred_boundary = model(image.cuda(device))
                # FIX: interpolate once and reuse — the original called
                # interp(pred_main) twice per model.
                output_pred = interp(pred_main)
                output_ = output_pred.cpu().data[0].numpy()
                if output is None:
                    output = model_weight * output_
                else:
                    output += model_weight * output_

                domain = name[0].split('/')[-2]
                save_path = folder_path + "/" + domain + "_" + name[0].split('/')[-1].split('.')[0]

                # segmentation prediction save
                # FIX: explicit dim=1 — implicit-dim F.softmax is deprecated.
                save_image(torch.from_numpy(np.array(colorize_mask(
                    np.asarray(np.argmax(F.softmax(output_pred, dim=1).cpu().data[0].numpy().transpose(1, 2, 0), axis=2),
                               dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
                           save_path+"_3_seg.png", 3, normalize=False, range=(0, 255))

                # boundary prediction save
                pred_boundary = F.interpolate(pred_boundary, label.shape[1:], mode='bilinear')
                save_image(pred_boundary.clone(), save_path+"_1_boundary.png", normalize=True)

                # red boundary visualize
                vis_red_boundary(save_path, pred_boundary.clone())

                # binary boundary prediction save
                save_image_binary(save_path, pred_boundary.clone(), threshold=0.5)

                # color label save
                save_image(torch.from_numpy(np.array(
                    colorize_mask(np.asarray(label.squeeze(0).numpy(),
                                             dtype=np.uint8)).convert('RGB')).transpose(2, 0, 1)),
                           save_path+"_2_label.png", 3, normalize=False)


            assert output is not None, 'Output is None'
            output = output.transpose(1, 2, 0)
            output = np.argmax(output, axis=2)
        label = label.numpy()[0]
        hist += fast_hist(label.flatten(), output.flatten(), cfg.NUM_CLASSES)
    inters_over_union_classes = per_class_iu(hist)
    print(f'mIoU = \t{round(np.nanmean(inters_over_union_classes) * 100, 2)}')
    if verbose:
        display_stats(cfg, test_loader.dataset.class_names, inters_over_union_classes)
Exemplo n.º 6
0
def train_advent(model, trainloader, targetloader, cfg, args):
    ''' UDA training with advent
    '''
    # Create the model and start the training.
    # pdb.set_trace()
    input_size_source = cfg.TRAIN.INPUT_SIZE_SOURCE
    input_size_target = cfg.TRAIN.INPUT_SIZE_TARGET
    SRC_IMG_MEAN = np.asarray(cfg.TRAIN.IMG_MEAN, dtype=np.float32)
    SRC_IMG_MEAN = torch.reshape(torch.from_numpy(SRC_IMG_MEAN), (1, 3, 1, 1))

    device = cfg.GPU_ID
    num_classes = cfg.NUM_CLASSES
    viz_tensorboard = os.path.exists(cfg.TRAIN.TENSORBOARD_LOGDIR)
    if viz_tensorboard:
        writer = SummaryWriter(log_dir=cfg.TRAIN.TENSORBOARD_LOGDIR)
        # -------------------------------------------------------- #
        # codes to initialize wandb for storing logs on its cloud
        wandb.init(project='FDA_integration_to_INTRA_DA')
        wandb.config.update(args)

        for key, val in cfg.items():
            wandb.config.update({key: val})

        wandb.watch(model)
        # -------------------------------------------------------- #

    # SEGMNETATION NETWORK
    model.train()
    model.to(device)
    cudnn.benchmark = True
    cudnn.enabled = True

    # DISCRIMINATOR NETWORK
    # feature-level
    d_aux = get_fc_discriminator(num_classes=num_classes)
    d_aux.train()
    d_aux.to(device)
    # restore_from = cfg.TRAIN.RESTORE_FROM_aux
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_aux, restore_from, device)

    # seg maps, i.e. output, level
    d_main = get_fc_discriminator(num_classes=num_classes)
    d_main.train()
    d_main.to(device)

    # restore_from = cfg.TRAIN.RESTORE_FROM_main
    # print("Load Discriminator:", restore_from)
    # load_checkpoint_for_evaluation(d_main, restore_from, device)

    # OPTIMIZERS
    # segnet's optimizer
    optimizer = optim.SGD(model.optim_parameters(cfg.TRAIN.LEARNING_RATE),
                          lr=cfg.TRAIN.LEARNING_RATE,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    # discriminators' optimizers
    optimizer_d_aux = optim.Adam(d_aux.parameters(),
                                 lr=cfg.TRAIN.LEARNING_RATE_D,
                                 betas=(0.9, 0.99))
    optimizer_d_main = optim.Adam(d_main.parameters(),
                                  lr=cfg.TRAIN.LEARNING_RATE_D,
                                  betas=(0.9, 0.99))

    # interpolate output segmaps
    interp = nn.Upsample(size=(input_size_source[1], input_size_source[0]),
                         mode='bilinear',
                         align_corners=True)
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear',
                                align_corners=True)

    # labels for adversarial training
    source_label = 0
    target_label = 1
    trainloader_iter = enumerate(trainloader)
    targetloader_iter = enumerate(targetloader)
    for i_iter in tqdm(range(cfg.TRAIN.EARLY_STOP + 1)):

        # reset optimizers
        optimizer.zero_grad()
        optimizer_d_aux.zero_grad()
        optimizer_d_main.zero_grad()
        # adapt LR if needed
        adjust_learning_rate(optimizer, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_aux, i_iter, cfg)
        adjust_learning_rate_discriminator(optimizer_d_main, i_iter, cfg)

        # UDA Training
        # only train segnet. Don't accumulate grads in disciminators
        for param in d_aux.parameters():
            param.requires_grad = False
        for param in d_main.parameters():
            param.requires_grad = False

        _, batch = trainloader_iter.__next__()
        images_source, labels, _, _ = batch

        _, batch = targetloader_iter.__next__()
        images, _, _, _ = batch

        # ----------------------------------------------------------------#
        B, C, H, W = images_source.shape

        mean_images_source = SRC_IMG_MEAN.repeat(B, 1, H, W)
        mean_images = SRC_IMG_MEAN.repeat(B, 1, H, W)

        if args.FDA_mode == 'on':
            # normalize the source and target image
            images_source -= mean_images_source
            images -= mean_images

        elif args.FDA_mode == 'off':
            # Keep source and target images as they are
            # no need to perform normalization again since that has been done already in dataset class(GTA5, cityscapes) when args.FDA_mode = 'off'
            images_source = images_source
            images = images

        else:
            raise KeyError()
        # ----------------------------------------------------------------#

        # debug:
        # labels=labels.numpy()
        # from matplotlib import pyplot as plt
        # import numpy as np
        # plt.figure(1), plt.imshow(labels[0]), plt.ion(), plt.colorbar(), plt.show()

        # train on source
        pred_src_aux, pred_src_main = model(images_source.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = interp(pred_src_aux)
            loss_seg_src_aux = loss_calc(pred_src_aux, labels, device)
        else:
            loss_seg_src_aux = 0
        pred_src_main = interp(pred_src_main)
        loss_seg_src_main = loss_calc(pred_src_main, labels, device)
        # pdb.set_trace()
        loss = (cfg.TRAIN.LAMBDA_SEG_MAIN * loss_seg_src_main +
                cfg.TRAIN.LAMBDA_SEG_AUX * loss_seg_src_aux)
        loss.backward()

        # adversarial training ot fool the discriminator
        pred_trg_aux, pred_trg_main = model(images.cuda(device))
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = interp_target(pred_trg_aux)
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_adv_trg_aux = bce_loss(d_out_aux, source_label)
        else:
            loss_adv_trg_aux = 0
        pred_trg_main = interp_target(pred_trg_main)
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_adv_trg_main = bce_loss(d_out_main, source_label)
        loss = (cfg.TRAIN.LAMBDA_ADV_MAIN * loss_adv_trg_main +
                cfg.TRAIN.LAMBDA_ADV_AUX * loss_adv_trg_aux)
        loss = loss
        loss.backward()

        # Train discriminator networks
        # enable training mode on discriminator networks
        for param in d_aux.parameters():
            param.requires_grad = True
        for param in d_main.parameters():
            param.requires_grad = True
        # train with source
        if cfg.TRAIN.MULTI_LEVEL:
            pred_src_aux = pred_src_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_src_aux)))
            loss_d_aux = bce_loss(d_out_aux, source_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        pred_src_main = pred_src_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_src_main)))
        loss_d_main = bce_loss(d_out_main, source_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        # train with target
        if cfg.TRAIN.MULTI_LEVEL:
            pred_trg_aux = pred_trg_aux.detach()
            d_out_aux = d_aux(prob_2_entropy(F.softmax(pred_trg_aux)))
            loss_d_aux = bce_loss(d_out_aux, target_label)
            loss_d_aux = loss_d_aux / 2
            loss_d_aux.backward()
        else:
            loss_d_aux = 0
        pred_trg_main = pred_trg_main.detach()
        d_out_main = d_main(prob_2_entropy(F.softmax(pred_trg_main)))
        loss_d_main = bce_loss(d_out_main, target_label)
        loss_d_main = loss_d_main / 2
        loss_d_main.backward()

        optimizer.step()
        if cfg.TRAIN.MULTI_LEVEL:
            optimizer_d_aux.step()
        optimizer_d_main.step()

        if i_iter % cfg.TRAIN.SAVE_PRED_EVERY == 0 and i_iter != 0:
            print('taking snapshot ...')
            print('exp =', cfg.TRAIN.SNAPSHOT_DIR)
            snapshot_dir = Path(cfg.TRAIN.SNAPSHOT_DIR)
            torch.save(model.state_dict(),
                       snapshot_dir / f'model_{i_iter}.pth')
            torch.save(d_aux.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_aux.pth')
            torch.save(d_main.state_dict(),
                       snapshot_dir / f'model_{i_iter}_D_main.pth')
            if i_iter >= cfg.TRAIN.EARLY_STOP - 1:
                break
        sys.stdout.flush()

        # Visualize with tensorboard
        if viz_tensorboard:
            # ----------------------------------------------------------------#

            if i_iter % cfg.TRAIN.TENSORBOARD_VIZRATE == cfg.TRAIN.TENSORBOARD_VIZRATE - 1:
                current_losses = {
                    'loss_seg_src_aux': loss_seg_src_aux,
                    'loss_seg_src_main': loss_seg_src_main,
                    'loss_adv_trg_aux': loss_adv_trg_aux,
                    'loss_adv_trg_main': loss_adv_trg_main,
                    'loss_d_aux': loss_d_aux,
                    'loss_d_main': loss_d_main
                }
                print_losses(current_losses, i_iter)

                log_losses_tensorboard(writer, current_losses, i_iter)
                draw_in_tensorboard(writer, images + mean_images, i_iter,
                                    pred_trg_main, num_classes, 'T')
                draw_in_tensorboard(writer, images_source + mean_images_source,
                                    i_iter, pred_src_main, num_classes, 'S')

                wandb.log({'loss': current_losses}, step=(i_iter + 1))
                if i_iter % (cfg.TRAIN.TENSORBOARD_VIZRATE
                             == cfg.TRAIN.TENSORBOARD_VIZRATE
                             ) * 25 - 1:  # for every 2500 iteration
                    wandb.log(
                        {'source': wandb.Image(torch.flip(images_source+mean_images_source, [1]).cpu().data[0].numpy().transpose((1, 2, 0))), \
                         'target': wandb.Image(torch.flip(images+mean_images, [1]).cpu().data[0].numpy().transpose((1, 2, 0))),
                         'pesudo label': wandb.Image(np.asarray(colorize_mask(np.asarray(labels.cpu().data.numpy().transpose(1,2,0).reshape((512,1024)), dtype=np.uint8)).convert('RGB')) )},
                        step=(i_iter + 1))