Example 1
def main():
    # Use a context manager so the config file is closed after loading,
    # matching the pattern in the later examples.
    with open(args.model_config_path, 'rb') as fp:
        config = CfgNode.load_cfg(fp)
    ckpt_path = args.model_weight_path

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    model = get_model(config)
    model.load_state_dict(
        torch.load(ckpt_path, map_location='cpu')['state_dict'])
    model.eval()

    model_wrapper = ModelWraper(model, args.is_use_gpu,
                                config.MODEL.IS_SPLIT_LOSS)
    model_wrapper.eval()

    image_transforms = build_image_transforms()

    pre_image = imread(args.in_pre_path)
    post_image = imread(args.in_post_path)

    inputs_pre = image_transforms(pre_image)
    inputs_post = image_transforms(post_image)
    inputs_pre.unsqueeze_(0)
    inputs_post.unsqueeze_(0)

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        loc, cls = model_wrapper(inputs_pre, inputs_post)

    if config.MODEL.IS_SPLIT_LOSS:
        loc, cls = argmax(loc, cls)
        loc = loc.detach().cpu().numpy().astype(np.uint8)[0]
        cls = cls.detach().cpu().numpy().astype(np.uint8)[0]
    else:
        loc = torch.argmax(loc, dim=1, keepdim=False)
        loc = loc.detach().cpu().numpy().astype(np.uint8)[0]
        cls = copy.deepcopy(loc)

    imsave(args.out_loc_path, loc)
    imsave(args.out_cls_path, cls)

    if args.is_vis:
        mask_map_img = np.zeros((cls.shape[0], cls.shape[1], 3),
                                dtype=np.uint8)
        mask_map_img[cls == 1] = (255, 255, 255)
        mask_map_img[cls == 2] = (229, 255, 50)
        mask_map_img[cls == 3] = (255, 159, 0)
        mask_map_img[cls == 4] = (255, 0, 0)
        compare_img = np.concatenate((pre_image, mask_map_img, post_image),
                                     axis=1)

        out_dir = os.path.dirname(args.out_loc_path)
        imsave(os.path.join(out_dir, 'compare_img.png'), compare_img)
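
Note: these scripts read their paths and flags from a module-level args object that is not shown. A minimal sketch of the command-line parsing Example 1 assumes; the parser itself is hypothetical, with flag names taken from the attributes used above:

import argparse

parser = argparse.ArgumentParser(description='dual-HRNet pre/post disaster inference')
parser.add_argument('--model_config_path', required=True, help='YAML model config')
parser.add_argument('--model_weight_path', required=True, help='checkpoint path')
parser.add_argument('--in_pre_path', required=True, help='pre-disaster image')
parser.add_argument('--in_post_path', required=True, help='post-disaster image')
parser.add_argument('--out_loc_path', required=True, help='localization mask output')
parser.add_argument('--out_cls_path', required=True, help='damage mask output')
parser.add_argument('--is_use_gpu', action='store_true', help='run on GPU')
parser.add_argument('--is_vis', action='store_true', help='save a comparison image')
args = parser.parse_args()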
Example 2
def main():
    if args.config_path:
        with open(args.config_path, 'rb') as fp:
            config = CfgNode.load_cfg(fp)
    else:
        # NOTE: the code below dereferences config (config.TRAIN.*), so a
        # config file is effectively required here.
        config = None

    ckpts_save_dir = args.ckpt_save_dir
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    test_model = None
    max_epoch = config.TRAIN.NUM_EPOCHS
    if 'test' in args:
        test_model = args.test
    print('data folder: ', args.data_dir)

    # WORLD_SIZE is set by torch.distributed.launch
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    is_distributed = num_gpus > 1
    if is_distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
        )

    model = get_model(config)
    model_loss = ModelLossWraper(model,
                                 config.TRAIN.CLASS_WEIGHTS,
                                 config.MODEL.IS_DISASTER_PRED,
                                 config.MODEL.IS_SPLIT_LOSS,
                                 )

    if is_distributed:
        model_loss = nn.SyncBatchNorm.convert_sync_batchnorm(model_loss)
        model_loss = nn.parallel.DistributedDataParallel(
            model_loss, device_ids=[args.local_rank], output_device=args.local_rank
        )

    trainset = XView2Dataset(args.data_dir, rgb_bgr='rgb',
                             preprocessing={'flip': True,
                                            'scale': config.TRAIN.MULTI_SCALE,
                                            'crop': config.TRAIN.CROP_SIZE,
                                            })

    if is_distributed:
        train_sampler = DistributedSampler(trainset)
    else:
        train_sampler = None

    trainset_loader = torch.utils.data.DataLoader(trainset, batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
                                                  shuffle=train_sampler is None, pin_memory=True, drop_last=True,
                                                  sampler=train_sampler, num_workers=num_gpus)

    model.train()

    lr_init = config.TRAIN.LR
    optimizer = torch.optim.SGD([{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': lr_init}],
                                lr=lr_init,
                                momentum=0.9,
                                weight_decay=0.,
                                nesterov=False,
                                )

    start_epoch = 0
    losses = AverageMeter()
    model.train()
    num_iters = max_epoch * len(trainset_loader)
    for epoch in range(start_epoch, max_epoch):
        if is_distributed:
            train_sampler.set_epoch(epoch)
        cur_iters = epoch * len(trainset_loader)

        for i, samples in enumerate(trainset_loader):
            lr = adjust_learning_rate(optimizer, lr_init, num_iters, i + cur_iters)

            inputs_pre = samples['pre_img']
            inputs_post = samples['post_img']
            target = samples['mask_img']
            disaster_target = samples['disaster']

            loss = model_loss(inputs_pre, inputs_post, target, disaster_target)

            loss_sum = torch.sum(loss).detach().cpu()
            if torch.isnan(loss_sum) or torch.isinf(loss_sum):
                logger.warning('non-finite loss at iter %d', i + cur_iters)
            losses.update(loss_sum, inputs_pre.size(0))  # actual batch size

            loss = torch.sum(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if args.local_rank == 0 and i % 10 == 0:
                logger.info('epoch: {0}\t'
                            'iter: {1}/{2}\t'
                            'lr: {3:.6f}\t'
                            'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                    epoch + 1, i + 1, len(trainset_loader), lr, loss=losses))

        if args.local_rank == 0:
            if (epoch + 1) % 50 == 0 and test_model is None:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(ckpts_save_dir, 'hrnet_%s' % (epoch + 1)))
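
adjust_learning_rate() is not shown in these examples. A plausible sketch, assuming the polynomial ("poly") decay commonly paired with HRNet; the exponent 0.9 is an assumption:

def adjust_learning_rate(optimizer, lr_init, num_iters, cur_iter, power=0.9):
    # Poly schedule: decay lr_init toward 0 over num_iters total iterations.
    lr = lr_init * (1 - cur_iter / num_iters) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr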
Example 3
def main():
    if args.config_path:
        with open(args.config_path, 'rb') as fp:
            config = CfgNode.load_cfg(fp)
    else:
        # NOTE: the code below dereferences config (config.MODEL.*), so a
        # config file is effectively required here.
        config = None

    ckpt_path = args.ckpt_path
    result_submit_dir = os.path.join(args.result_dir, 'submit/')
    result_compare_dir = os.path.join(args.result_dir, 'compare/')
    dataset_mode = 'test' if not args.is_train_data else 'train'
    imgs_dir = os.path.join(args.data_path, 'test/images/') if dataset_mode == 'test' \
        else os.path.join(args.data_path, 'tier3/images/')

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    print('data folder: ', args.data_path)

    safe_mkdir(result_submit_dir)
    safe_mkdir(result_compare_dir)

    model = get_model(config)
    model.load_state_dict(
        torch.load(ckpt_path, map_location='cpu')['state_dict'])
    model.eval()
    model_wrapper = ModelWraper(model, args.is_use_gpu,
                                config.MODEL.IS_SPLIT_LOSS)
    # model_wrapper = nn.DataParallel(model_wrapper)
    model_wrapper.eval()

    testset = XView2Dataset(args.data_path,
                            rgb_bgr='rgb',
                            preprocessing={
                                'flip': False,
                                'scale': None,
                                'crop': None
                            },
                            mode=dataset_mode)
    testset_loader = torch.utils.data.DataLoader(testset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 pin_memory=False,
                                                 num_workers=1)

    for i, samples in enumerate(tqdm(testset_loader)):
        if dataset_mode == 'train' and i < 5520:
            continue  # hard-coded skip, presumably to resume a partial run
        inputs_pre = samples['pre_img']
        inputs_post = samples['post_img']
        image_ids = samples['image_id']

        with torch.no_grad():  # inference only: skip autograd bookkeeping
            loc, cls = model_wrapper(inputs_pre, inputs_post)

        if config.MODEL.IS_SPLIT_LOSS:
            loc, cls = argmax(loc, cls)
            loc = loc.detach().cpu().numpy().astype(np.uint8)
            cls = cls.detach().cpu().numpy().astype(np.uint8)
        else:
            loc = torch.argmax(loc, dim=1, keepdim=False)
            loc = loc.detach().cpu().numpy().astype(np.uint8)
            cls = copy.deepcopy(loc)

        for image_id, l, c in zip(image_ids, loc, cls):
            localization_filename = 'test_localization_%s_prediction.png' % image_id
            damage_filename = 'test_damage_%s_prediction.png' % image_id

            imsave(os.path.join(result_submit_dir, localization_filename), l)
            imsave(os.path.join(result_submit_dir, damage_filename), c)

            pre_filename = 'test_pre_%s.png' % image_id
            post_filename = 'test_post_%s.png' % image_id
            pre_image = imread(os.path.join(imgs_dir, pre_filename))
            post_image = imread(os.path.join(imgs_dir, post_filename))

            mask_map_img = np.zeros((c.shape[0], c.shape[1], 3),
                                    dtype=np.uint8)
            mask_map_img[c == 1] = (255, 255, 255)
            mask_map_img[c == 2] = (229, 255, 50)
            mask_map_img[c == 3] = (255, 159, 0)
            mask_map_img[c == 4] = (255, 0, 0)
            compare_img = np.concatenate((pre_image, mask_map_img, post_image),
                                         axis=1)

            compare_filename = 'test_%s.png' % image_id
            imsave(os.path.join(result_compare_dir, compare_filename),
                   compare_img)
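
The argmax(loc, cls) helper used in the IS_SPLIT_LOSS branches is defined elsewhere in the project. A minimal sketch, assuming a two-channel localization head and a five-channel damage head, and assuming damage predictions are kept only where a building is localized:

import torch

def argmax(loc, cls):
    # loc: (N, 2, H, W) building/no-building logits
    # cls: (N, 5, H, W) damage-level logits
    loc = torch.argmax(loc, dim=1, keepdim=False)
    cls = torch.argmax(cls, dim=1, keepdim=False)
    cls = cls * (loc > 0).long()  # suppress damage outside localized buildings
    return loc, cls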
Example 4
def main():
    if args.config_path:
        if args.config_path in CONFIG_TREATER:
            load_path = CONFIG_TREATER[args.config_path]
        elif args.config_path.endswith(".yaml"):
            load_path = args.config_path
        else:
            # args.config_path is not a CONFIG_TREATER key in this branch, so
            # use it directly (as Example 5 does) rather than re-indexing the
            # dict, which would raise a KeyError.
            load_path = "experiments/" + args.config_path + ".yaml"
        with open(load_path, 'rb') as fp:
            config = CfgNode.load_cfg(fp)
    else:
        # NOTE: the code below dereferences config (config.TRAIN.*), so a
        # config file is effectively required here.
        config = None

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    test_model = None
    max_epoch = config.TRAIN.NUM_EPOCHS
    # ckpts_save_dir was undefined in this example but is used when saving
    # checkpoints below; assumed to come from the CLI as in Example 2.
    ckpts_save_dir = args.ckpt_save_dir
    print('data folder: ', args.data_folder)

    # The torch.distributed.launch / NCCL setup from Example 2 is disabled in
    # this variant; single-process nn.DataParallel is used instead (below).

    model = get_model(config)
    model_loss = ModelLossWraper(
        model,
        config.TRAIN.CLASS_WEIGHTS,
        config.MODEL.IS_DISASTER_PRED,
        config.MODEL.IS_SPLIT_LOSS,
    ).cuda()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        # Single-process multi-GPU fallback in place of the disabled DDP path.
        model_loss = nn.DataParallel(model_loss)

    model_loss.to(device)
    cpucount = multiprocessing.cpu_count()

    if config.mode.startswith("single"):
        # Per-disaster training: config.mode is "single<split>", so mode[6:]
        # names the disaster split to draw from.
        trainset_loaders = {}
        loader_len = 0
        for disaster in disaster_list[config.mode[6:]]:
            trainset = XView2Dataset(args.data_folder,
                                     rgb_bgr='rgb',
                                     preprocessing={
                                         'flip': True,
                                         'scale': config.TRAIN.MULTI_SCALE,
                                         'crop': config.TRAIN.CROP_SIZE,
                                     },
                                     mode="singletrain",
                                     single_disaster=disaster)
            if len(trainset) > 0:
                train_sampler = None

                trainset_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
                    shuffle=train_sampler is None,
                    pin_memory=True,
                    drop_last=True,
                    sampler=train_sampler,
                    num_workers=cpucount if cpucount < 16 else cpucount // 3)

                trainset_loaders[disaster] = trainset_loader
                loader_len += len(trainset_loader)
                print("added disaster {} with {} samples".format(
                    disaster, len(trainset)))
            else:
                print("skipping disaster ", disaster)

    else:

        trainset = XView2Dataset(args.data_folder,
                                 rgb_bgr='rgb',
                                 preprocessing={
                                     'flip': True,
                                     'scale': config.TRAIN.MULTI_SCALE,
                                     'crop': config.TRAIN.CROP_SIZE,
                                 },
                                 mode=config.mode)

        train_sampler = None  # DistributedSampler removed with the DDP path

        trainset_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            shuffle=train_sampler is None,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        loader_len = len(trainset_loader)

    model.train()

    lr_init = config.TRAIN.LR
    optimizer = torch.optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad, model.parameters()),
            'lr': lr_init
        }],
        lr=lr_init,
        momentum=0.9,
        weight_decay=0.,
        nesterov=False,
    )

    num_iters = max_epoch * loader_len

    if config.SWA:
        # Stochastic Weight Averaging: start averaging once the base schedule
        # has finished (swa_start is in optimizer steps), snapshot every four
        # epochs' worth of iterations, and train 40 extra epochs at swa_lr.
        swa_start = num_iters
        optimizer = SWA(optimizer,
                        swa_start=swa_start,
                        swa_freq=4 * loader_len,
                        swa_lr=0.001)
        lr = 0.0001
        max_epoch = max_epoch + 40

    start_epoch = 0
    losses = AverageMeter()
    model.train()
    cur_iters = 0 if start_epoch == 0 else None
    for epoch in range(start_epoch, max_epoch):

        if config.mode.startswith("single"):
            all_batches = []
            total_len = 0
            for disaster in sorted(list(trainset_loaders.keys())):
                all_batches += [
                    (disaster, idx)
                    for idx in range(len(trainset_loaders[disaster]))
                ]
                total_len += len(trainset_loaders[disaster].dataset)
            # Shuffle the (disaster, batch-index) order across all loaders.
            all_batches = random.sample(all_batches, len(all_batches))
            iterators = {
                disaster: iter(trainset_loaders[disaster])
                for disaster in trainset_loaders.keys()
            }
            if cur_iters is not None:
                cur_iters += len(all_batches)
            else:
                cur_iters = epoch * len(all_batches)

            for i, (disaster, idx) in enumerate(all_batches):
                lr = optimizer.param_groups[0]['lr']
                # swa_start is measured in iterations, so compare the global
                # iteration count rather than the epoch index.
                if not config.SWA or i + cur_iters < swa_start:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters,
                                              i + cur_iters)
                samples = next(iterators[disaster])
                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)
                #disaster_target = samples['disaster'].to(device)

                loss = model_loss(inputs_pre, inputs_post,
                                  target)  #, disaster_target)

                loss_sum = torch.sum(loss).detach().cpu()
                if torch.isnan(loss_sum) or torch.isinf(loss_sum):
                    logger.warning('non-finite loss at iter %d', i + cur_iters)
                losses.update(loss_sum, inputs_pre.size(0))  # actual batch size

                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})\t'
                                'disaster: {dis}'.format(epoch + 1,
                                                         i + 1,
                                                         len(all_batches),
                                                         lr,
                                                         loss=losses,
                                                         dis=disaster))

            del iterators

        else:
            cur_iters = epoch * len(trainset_loader)

            for i, samples in enumerate(trainset_loader):
                lr = optimizer.param_groups[0]['lr']
                # As above: swa_start is in iterations, not epochs.
                if not config.SWA or i + cur_iters < swa_start:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters,
                                              i + cur_iters)

                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)
                #disaster_target = samples['disaster'].to(device)

                loss = model_loss(inputs_pre, inputs_post,
                                  target)  #, disaster_target)

                loss_sum = torch.sum(loss).detach().cpu()
                if torch.isnan(loss_sum) or torch.isinf(loss_sum):
                    logger.warning('non-finite loss at iter %d', i + cur_iters)
                losses.update(loss_sum, inputs_pre.size(0))  # actual batch size

                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                                    epoch + 1,
                                    i + 1,
                                    len(trainset_loader),
                                    lr,
                                    loss=losses))

        if args.local_rank == 0:
            if (epoch + 1) % 50 == 0 and test_model is None:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, os.path.join(ckpts_save_dir, 'hrnet_%s' % (epoch + 1)))
    if config.SWA:
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("preSWA")))
        optimizer.swap_swa_sgd()  # copy the SWA running average into the model
        # Re-estimate BatchNorm running statistics under the averaged weights.
        # NOTE: in "single" mode this reuses the last disaster's trainset and
        # train_sampler from the loop above.
        bn_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=2,
            shuffle=train_sampler is None,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        bn_update(bn_loader, model, device='cuda')
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("SWA")))
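
The SWA wrapper above follows the torchcontrib.optim.SWA API: averaging starts after swa_start optimizer steps, a snapshot is taken every swa_freq steps at swa_lr, and swap_swa_sgd() copies the averaged weights into the model (after which BatchNorm statistics must be re-estimated, as done with bn_update above). A minimal standalone sketch with a stand-in model and random data:

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Linear(10, 2)  # stand-in model
base_opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
opt = SWA(base_opt, swa_start=1000, swa_freq=100, swa_lr=0.001)

for step in range(2000):  # hypothetical training loop
    x, y = torch.randn(8, 10), torch.randn(8, 2)
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # weight snapshots begin after swa_start steps

opt.swap_swa_sgd()  # load the averaged weights into the model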
Example 5
def main():
    if args.config_path:
        if args.config_path in CONFIG_TREATER:
            with open(
                    "experiments/" + CONFIG_TREATER[args.config_path] +
                    ".yaml", 'rb') as fp:
                config = CfgNode.load_cfg(fp)
        else:
            with open("experiments/" + args.config_path + ".yaml", 'rb') as fp:
                config = CfgNode.load_cfg(fp)
    else:
        # NOTE: the code below dereferences config (config.MODEL.*), so a
        # config file is effectively required here.
        config = None

    ckpt_path = args.weights
    if ckpt_path == "paper":
        ckpt_path = download_weights(args.config_path)

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    print('data folder: ', args.data_folder)

    model = get_model(config)
    model.load_state_dict(
        torch.load(ckpt_path, map_location='cpu')['state_dict'])
    model.eval()
    # Note: the GPU flag is inverted here relative to Examples 1 and 3.
    model_wrapper = ModelWraper(model, not args.is_use_gpu,
                                config.MODEL.IS_SPLIT_LOSS)
    # model_wrapper = nn.DataParallel(model_wrapper)
    model_wrapper.eval()

    for dataset_mode in [config.mode + "test", config.mode + "hold"]:
        result_submit_dir = "experiments/" + args.config_path + "/output/" + dataset_mode
        os.makedirs(result_submit_dir, exist_ok=True)
        safe_mkdir(os.path.join(result_submit_dir, "predictions"))
        safe_mkdir(os.path.join(result_submit_dir, "targets"))
        if dataset_mode.startswith("single"):
            with torch.no_grad():
                testset_loaders = {}
                for disaster in disaster_list[dataset_mode[6:]]:
                    testmode = "singletest" if "test" in dataset_mode else "singlehold"
                    testset = XView2Dataset(args.data_folder,
                                            rgb_bgr='rgb',
                                            preprocessing={
                                                'flip': False,
                                                'scale': None,
                                                'crop': None
                                            },
                                            mode=testmode,
                                            single_disaster=disaster)
                    if len(testset) > 0:
                        print("added disaster {} with {} samples".format(
                            disaster, len(testset)))
                        testset_loader = torch.utils.data.DataLoader(
                            testset,
                            batch_size=4,
                            shuffle=False,
                            pin_memory=False,
                            num_workers=4)

                        testset_loaders[disaster] = testset_loader
                    else:
                        print("skipped disaster ", disaster)

                for disaster in sorted(list(testset_loaders.keys())):
                    loader = testset_loaders[disaster]
                    if len(loader) == 0:
                        continue
                    print(disaster)
                    # Adapt BatchNorm statistics to this disaster's data
                    # before evaluating.
                    bn_update(loader, model, device='cuda')
                    for i, samples in enumerate(tqdm(loader)):
                        inputs_pre = samples['pre_img']
                        inputs_post = samples['post_img']
                        image_ids = samples['image_id']
                        if dataset_mode[6:] in [
                                "oodtest", "oodhold", "ood2test", "ood2hold",
                                "ood3test", "ood3hold", "guptatest",
                                "guptahold"
                        ]:
                            masks = samples['mask_img']

                        loc, cls = model_wrapper(inputs_pre, inputs_post)

                        if config.MODEL.IS_SPLIT_LOSS:
                            loc, cls = argmax(loc, cls)
                            loc = loc.detach().cpu().numpy().astype(np.uint8)
                            cls = cls.detach().cpu().numpy().astype(np.uint8)
                        else:
                            loc = torch.argmax(loc, dim=1, keepdim=False)
                            loc = loc.detach().cpu().numpy().astype(np.uint8)
                            cls = copy.deepcopy(loc)

                        for j, (image_id, l,
                                c) in enumerate(zip(image_ids, loc, cls)):
                            localization_filename = 'test_localization_%s_prediction.png' % image_id
                            damage_filename = 'test_damage_%s_prediction.png' % image_id

                            imsave(
                                os.path.join(result_submit_dir, "predictions",
                                             localization_filename), l)
                            imsave(
                                os.path.join(result_submit_dir, "predictions",
                                             damage_filename), c)

                            if dataset_mode[6:] in [
                                    "oodtest", "oodhold", "ood2test",
                                    "ood2hold", "ood3test", "ood3hold",
                                    "guptatest", "guptahold"
                            ]:
                                localization_filename = 'test_localization_%s_target.png' % image_id
                                damage_filename = 'test_damage_%s_target.png' % image_id

                                mask = masks[j]
                                mask[mask == 255] = 0
                                mask = mask.cpu().numpy().astype(np.uint8)

                                imsave(
                                    os.path.join(result_submit_dir, "targets",
                                                 localization_filename),
                                    (1 * (mask > 0)))
                                imsave(
                                    os.path.join(result_submit_dir, "targets",
                                                 damage_filename), mask)
                    #model_wrapper.model.cpu()

        else:
            testset = XView2Dataset(args.data_folder,
                                    rgb_bgr='rgb',
                                    preprocessing={
                                        'flip': False,
                                        'scale': None,
                                        'crop': None
                                    },
                                    mode=dataset_mode)
            testset_loader = torch.utils.data.DataLoader(testset,
                                                         batch_size=1,
                                                         shuffle=False,
                                                         pin_memory=False,
                                                         num_workers=1)

            for i, samples in enumerate(tqdm(testset_loader)):
                inputs_pre = samples['pre_img']
                inputs_post = samples['post_img']
                image_ids = samples['image_id']
                if dataset_mode in [
                        "oodtest", "oodhold", "ood2test", "ood2hold",
                        "ood3test", "ood3hold", "guptatest", "guptahold"
                ]:
                    masks = samples['mask_img']

                with torch.no_grad():  # inference only, as in the branch above
                    loc, cls = model_wrapper(inputs_pre, inputs_post)

                if config.MODEL.IS_SPLIT_LOSS:
                    loc, cls = argmax(loc, cls)
                    loc = loc.detach().cpu().numpy().astype(np.uint8)
                    cls = cls.detach().cpu().numpy().astype(np.uint8)
                else:
                    loc = torch.argmax(loc, dim=1, keepdim=False)
                    loc = loc.detach().cpu().numpy().astype(np.uint8)
                    cls = copy.deepcopy(loc)

                for j, (image_id, l, c) in enumerate(zip(image_ids, loc, cls)):
                    localization_filename = 'test_localization_%s_prediction.png' % image_id
                    damage_filename = 'test_damage_%s_prediction.png' % image_id

                    imsave(
                        os.path.join(result_submit_dir, "predictions",
                                     localization_filename), l)
                    imsave(
                        os.path.join(result_submit_dir, "predictions",
                                     damage_filename), c)

                    if dataset_mode in [
                            "oodtest", "oodhold", "ood2test", "ood2hold",
                            "ood3test", "ood3hold", "guptatest", "guptahold"
                    ]:
                        localization_filename = 'test_localization_%s_target.png' % image_id
                        damage_filename = 'test_damage_%s_target.png' % image_id

                        mask = masks[j]
                        mask[mask == 255] = 0
                        mask = mask.cpu().numpy().astype(np.uint8)

                        imsave(
                            os.path.join(result_submit_dir, "targets",
                                         localization_filename),
                            (1 * (mask > 0)))
                        imsave(
                            os.path.join(result_submit_dir, "targets",
                                         damage_filename), mask)

        # Rename prediction/target pairs to the sequentially numbered filename
        # scheme expected by the scoring code.
        base = result_submit_dir
        for i, p in enumerate(os.listdir(os.path.join(base, "predictions"))):
            if "damage" in p:
                os.rename(
                    os.path.join(base, "targets",
                                 p.replace("prediction", "target")),
                    os.path.join(
                        base, "targets",
                        "_".join(p.split("_")[:2] + [str(i).zfill(6)] +
                                 ["target.png"])))
                os.rename(
                    os.path.join(base, "predictions", p),
                    os.path.join(
                        base, "predictions",
                        "_".join(p.split("_")[:2] + [str(i).zfill(6)] +
                                 [p.split("_")[-1]])))
                p = p.replace("damage", "localization")
                os.rename(
                    os.path.join(base, "targets",
                                 p.replace("prediction", "target")),
                    os.path.join(
                        base, "targets",
                        "_".join(p.split("_")[:2] + [str(i).zfill(6)] +
                                 ["target.png"])))
                os.rename(
                    os.path.join(base, "predictions", p),
                    os.path.join(
                        base, "predictions",
                        "_".join(p.split("_")[:2] + [str(i).zfill(6)] +
                                 [p.split("_")[-1]])))

        MetricsInstance = XviewMetrics(os.path.join(base, "predictions"),
                                       os.path.join(base, "targets"))

        MetricsInstance.compute_score(
            os.path.join(base, "predictions"), os.path.join(base, "targets"),
            os.path.join(base, "Results_{}.json".format(dataset_mode)))