Example #1: run_test evaluates a (feature_extractor, classifier) pair on the test split, reducing per-class intersection/union counts across processes and logging mIoU/mAcc/allAcc.
def run_test(cfg, model, local_rank, distributed):
    logger = logging.getLogger("FADA.tester")
    if local_rank == 0:
        logger.info('>>>>>>>>>>>>>>>> Start Testing >>>>>>>>>>>>>>>>')
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    feature_extractor, classifier = model
    
    if distributed:
        feature_extractor, classifier = feature_extractor.module, classifier.module
    torch.cuda.empty_cache()  # TODO check if it helps
    dataset_name = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        mkdir(output_folder)

    test_data = build_dataset(cfg, mode='test', is_source=False)
    if distributed:
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_data)
    else:
        test_sampler = None
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, sampler=test_sampler)
    feature_extractor.eval()
    classifier.eval()
    end = time.time()
    with torch.no_grad():
        for i, (x, y, _) in enumerate(test_loader):
            x = x.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).long()

            size = y.shape[-2:]
            pred = classifier(feature_extractor(x))
            pred = F.interpolate(pred, size=size, mode='bilinear', align_corners=True)
            
            output = pred.max(1)[1]
            intersection, union, target = intersectionAndUnionGPU(output, y, cfg.MODEL.NUM_CLASSES, cfg.INPUT.IGNORE_LABEL)
            if distributed:
                torch.distributed.all_reduce(intersection)
                torch.distributed.all_reduce(union)
                torch.distributed.all_reduce(target)
            intersection = intersection.cpu().numpy()
            union = union.cpu().numpy()
            target = target.cpu().numpy()
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)

            # running pixel accuracy over batches seen so far (not logged here)
            accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
            batch_time.update(time.time() - end)
            end = time.time()
    
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if local_rank == 0:
        logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
        for i in range(cfg.MODEL.NUM_CLASSES):
            logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, iou_class[i], accuracy_class[i]))
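Note: these examples rely on helpers that are not shown. For orientation, here is a minimal sketch of the AverageMeter they assume, in the running-sum style common to semantic-segmentation codebases (the repo's own version may differ in detail):

class AverageMeter:
    """Tracks the latest value and a running sum/count/average.

    When update() is fed per-class numpy arrays (as with the
    intersection/union counts above), .sum accumulates elementwise,
    which is why intersection_meter.sum / union_meter.sum yields
    per-class IoU.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum = self.sum + val * n
        self.count += n
        self.avg = self.sum / self.count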
Example #2: get_threshold runs inference over the target data and derives per-class confidence thresholds for pseudo-labeling.
def get_threshold(cfg):
    logger = logging.getLogger("pseudo_label.trainer")
    logger.info("Start inference on target dataset and get threshold of each class")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = strip_prefix_if_present(checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = strip_prefix_if_present(checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    feature_extractor.eval()
    classifier.eval()

    torch.cuda.empty_cache()
    tgt_train_data = build_dataset(cfg, mode='test', is_source=False)
    tgt_train_loader = torch.utils.data.DataLoader(tgt_train_data,
                                                   batch_size=cfg.SOLVER.BATCH_SIZE_VAL,
                                                   shuffle=False,
                                                   num_workers=4,
                                                   pin_memory=True,
                                                   sampler=None,
                                                   drop_last=False)

    cpseudo_label = PseudoLabel(cfg)
    for batch in tqdm(tgt_train_loader):
        x, _, name = batch
        tgt_input = x.cuda(non_blocking=True)
        tgt_size = tgt_input.shape[-2:]
        with torch.no_grad():
            output = classifier(feature_extractor(tgt_input))
        output = F.interpolate(output, size=tgt_size, mode='bilinear', align_corners=True)
        cpseudo_label.update_pseudo_label(output)
    thres_const = cpseudo_label.get_threshold_const(thred=0.9)
    cpseudo_label.save_results()

    return thres_const
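PseudoLabel is repo-specific and not shown. A hypothetical sketch of the CBST-style behaviour its interface suggests: accumulate each pixel's max softmax confidence, bucketed by predicted class, then threshold at a fixed quantile capped by thred (the median quantile here is an assumption; save_results is omitted):

import numpy as np
import torch.nn.functional as F

class PseudoLabelSketch:  # hypothetical stand-in for PseudoLabel(cfg)
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.probs = [[] for _ in range(num_classes)]

    def update_pseudo_label(self, output):
        # output: (B, C, H, W) logits; bucket each pixel's confidence
        # under its argmax class.
        conf, label = F.softmax(output, dim=1).max(dim=1)
        for c in range(self.num_classes):
            self.probs[c].append(conf[label == c].detach().cpu().numpy())

    def get_threshold_const(self, thred=0.9):
        # Per class: the confidence at a fixed quantile, capped at thred.
        thres = np.full(self.num_classes, thred)
        for c in range(self.num_classes):
            if self.probs[c]:
                x = np.sort(np.concatenate(self.probs[c]))[::-1]
                if x.size:
                    thres[c] = min(x[int(x.size * 0.5)], thred)
        return thres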
Example #3: the SDCA training loop combines source supervision (optionally with Lovász-softmax) with feature-level and output-level pixel contrastive losses against running per-class semantic distributions.
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("SDCA.trainer")

    # create network
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor = build_feature_extractor(cfg)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if local_rank == 0:
        print(classifier)

    # batch size: half for source and half for target
    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        batch_size = int(cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg1
        )
        pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg2
        )
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    # init optimizer
    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(), lr=cfg.SOLVER.BASE_LR * 10, momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    # load checkpoint
    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_weights = checkpoint['feature_extractor'] if distributed else strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_weights)
        classifier_weights = checkpoint['classifier'] if distributed else strip_prefix_if_present(
            checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    # init data loader
    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)
    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None
    src_train_loader = torch.utils.data.DataLoader(src_train_data, batch_size=batch_size,
                                                   shuffle=(src_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=src_train_sampler, drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(tgt_train_data, batch_size=batch_size,
                                                   shuffle=(tgt_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=tgt_train_sampler, drop_last=True)

    # init loss
    ce_criterion = nn.CrossEntropyLoss(ignore_index=255)
    pcl_criterion = PixelContrastiveLoss(cfg)

    # load semantic distributions
    logger.info(">>>>>>>>>>>>>>>> Load semantic distributions >>>>>>>>>>>>>>>>")
    _, backbone_name = cfg.MODEL.NAME.split('_')
    feature_num = 2048 if backbone_name.startswith('resnet') else 1024
    feat_estimator = semantic_dist_estimator(feature_num=feature_num, cfg=cfg)
    if cfg.SOLVER.MULTI_LEVEL:
        out_estimator = semantic_dist_estimator(feature_num=cfg.MODEL.NUM_CLASSES, cfg=cfg)

    iteration = 0
    start_training_time = time.time()
    end = time.time()
    save_to_disk = local_rank == 0
    max_iters = cfg.SOLVER.MAX_ITER
    meters = MetricLogger(delimiter="  ")

    logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
    feature_extractor.train()
    classifier.train()

    for i, ((src_input, src_label, src_name), (tgt_input, _, _)) in enumerate(zip(src_train_loader, tgt_train_loader)):

        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        src_feat = feature_extractor(src_input)
        src_out = classifier(src_feat)
        tgt_feat = feature_extractor(tgt_input)
        tgt_out = classifier(tgt_feat)
        tgt_out_softmax = F.softmax(tgt_out, dim=1)

        # supervision loss
        src_pred = F.interpolate(src_out, size=src_size, mode='bilinear', align_corners=True)
        if cfg.SOLVER.LAMBDA_LOV > 0:
            src_pred_softmax = F.softmax(src_pred, dim=1)
            loss_lov = lovasz_softmax(src_pred_softmax, src_label, ignore=255)
            loss_sup = ce_criterion(src_pred, src_label) + cfg.SOLVER.LAMBDA_LOV * loss_lov
            meters.update(loss_lov=loss_lov.item())
        else:
            loss_sup = ce_criterion(src_pred, src_label)
        meters.update(loss_sup=loss_sup.item())

        # source mask: downsample the ground-truth label
        B, A, Hs, Ws = src_feat.size()
        src_mask = F.interpolate(src_label.unsqueeze(0).float(), size=(Hs, Ws), mode='nearest').squeeze(0).long()
        src_mask = src_mask.contiguous().view(B * Hs * Ws, )
        assert not src_mask.requires_grad
        # target mask: constant threshold
        _, _, Ht, Wt = tgt_feat.size()
        tgt_out_maxvalue, tgt_mask = torch.max(tgt_out_softmax, dim=1)
        for c in range(cfg.MODEL.NUM_CLASSES):  # avoid shadowing the loop index i
            tgt_mask[(tgt_out_maxvalue < cfg.SOLVER.DELTA) & (tgt_mask == c)] = 255
        tgt_mask = tgt_mask.contiguous().view(B * Ht * Wt, )
        assert not tgt_mask.requires_grad

        src_feat = src_feat.permute(0, 2, 3, 1).contiguous().view(B * Hs * Ws, A)
        tgt_feat = tgt_feat.permute(0, 2, 3, 1).contiguous().view(B * Ht * Wt, A)
        # update feature-level statistics
        feat_estimator.update(features=src_feat.detach(), labels=src_mask)

        # contrastive loss on both domains
        loss_feat = pcl_criterion(Mean=feat_estimator.Mean.detach(),
                                  CoVariance=feat_estimator.CoVariance.detach(),
                                  feat=src_feat,
                                  labels=src_mask) \
                    + pcl_criterion(Mean=feat_estimator.Mean.detach(),
                                    CoVariance=feat_estimator.CoVariance.detach(),
                                    feat=tgt_feat,
                                    labels=tgt_mask)
        meters.update(loss_feat=loss_feat.item())

        if cfg.SOLVER.MULTI_LEVEL:
            src_out = src_out.permute(0, 2, 3, 1).contiguous().view(B * Hs * Ws, cfg.MODEL.NUM_CLASSES)
            tgt_out = tgt_out.permute(0, 2, 3, 1).contiguous().view(B * Ht * Wt, cfg.MODEL.NUM_CLASSES)

            # update output-level statistics
            out_estimator.update(features=src_out.detach(), labels=src_mask)

            # the proposed contrastive loss on prediction map
            loss_out = pcl_criterion(Mean=out_estimator.Mean.detach(),
                                     CoVariance=out_estimator.CoVariance.detach(),
                                     feat=src_out,
                                     labels=src_mask) \
                       + pcl_criterion(Mean=out_estimator.Mean.detach(),
                                       CoVariance=out_estimator.CoVariance.detach(),
                                       feat=tgt_out,
                                       labels=tgt_mask)
            meters.update(loss_out=loss_out.item())

            loss = loss_sup \
                   + cfg.SOLVER.LAMBDA_FEAT * loss_feat \
                   + cfg.SOLVER.LAMBDA_OUT * loss_out
        else:
            loss = loss_sup + cfg.SOLVER.LAMBDA_FEAT * loss_feat

        loss.backward()

        optimizer_fea.step()
        optimizer_cls.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        iteration += 1
        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.02f} GB"
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
                )
            )

        if (iteration == cfg.SOLVER.MAX_ITER or iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(cfg.OUTPUT_DIR, "model_iter{:06d}.pth".format(iteration))
            torch.save({'iteration': iteration,
                        'feature_extractor': feature_extractor.state_dict(),
                        'classifier': classifier.state_dict(),
                        'optimizer_fea': optimizer_fea.state_dict(),
                        'optimizer_cls': optimizer_cls.state_dict(),
                        }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / cfg.SOLVER.STOP_ITER
        )
    )

    return feature_extractor, classifier
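adjust_learning_rate is called with (method, base_lr, iteration, max_iters, power), and the returned value is scaled 10x for the classifier's parameter groups. A minimal sketch assuming the usual polynomial decay; the repo's version may support additional schedules:

def adjust_learning_rate(method, base_lr, iteration, max_iters, power=0.9):
    # "poly" decay, the standard schedule for DeepLab-style segmentation
    # training; other methods would branch here.
    if method == 'poly':
        return base_lr * ((1.0 - float(iteration) / max_iters) ** power)
    return base_lr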
Example #4: a source-only supervised training loop (plain cross-entropy on the source domain).
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("SelfSupervised.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if local_rank == 0:
        print(feature_extractor)
        print(classifier)

    batch_size = cfg.SOLVER.BATCH_SIZE
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(cfg.SOLVER.BATCH_SIZE /
                         torch.distributed.get_world_size())
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    src_train_data = build_dataset(cfg, mode='train', is_source=True)

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(src_train_data,
                                               batch_size=batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=4,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    ce_criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    max_iters = cfg.SOLVER.MAX_ITER
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    start_training_time = time.time()
    end = time.time()

    for i, (src_input, src_label, _) in enumerate(train_loader):
        data_time = time.time() - end
        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        size = src_label.shape[-2:]
        pred = classifier(feature_extractor(src_input), size)
        loss = ce_criterion(pred, src_label)
        loss.backward()

        optimizer_fea.step()
        optimizer_cls.step()
        meters.update(loss_seg=loss.item())
        iteration += 1

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.2f} GB",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 /
                    1024.0 / 1024.0,
                ))

        if (iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0
                or iteration == max_iters) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.STOP_ITER))

    return feature_extractor, classifier
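Checkpoints saved from DistributedDataParallel models carry a "module." prefix on every state-dict key, which is why the non-distributed branches call strip_prefix_if_present. A sketch modeled on the maskrcnn-benchmark helper of the same name:

from collections import OrderedDict

def strip_prefix_if_present(state_dict, prefix):
    keys = sorted(state_dict.keys())
    if not all(key.startswith(prefix) for key in keys):
        return state_dict  # nothing to strip
    stripped = OrderedDict()
    for key, value in state_dict.items():
        stripped[key[len(prefix):]] = value
    return stripped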
Example #5: semantic_dist_init makes one pass over the source set to initialize the feature-level and output-level class statistics, then saves them to disk.
def semantic_dist_init(cfg):
    logger = logging.getLogger("semantic_dist_init.trainer")

    _, backbone_name = cfg.MODEL.NAME.split('_')
    feature_num = 2048 if backbone_name.startswith('resnet') else 1024
    feat_estimator = semantic_dist_estimator(feature_num=feature_num, cfg=cfg)
    out_estimator = semantic_dist_estimator(feature_num=cfg.MODEL.NUM_CLASSES,
                                            cfg=cfg)

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    torch.cuda.empty_cache()

    # load checkpoint
    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_extractor_weights = strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_extractor_weights)
        classifier_weights = strip_prefix_if_present(checkpoint['classifier'],
                                                     'module.')
        classifier.load_state_dict(classifier_weights)

    src_train_data = build_dataset(cfg,
                                   mode='train',
                                   is_source=True,
                                   epochwise=True)
    src_train_loader = torch.utils.data.DataLoader(
        src_train_data,
        batch_size=cfg.SOLVER.BATCH_SIZE_VAL,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False)
    iteration = 0
    feature_extractor.eval()
    classifier.eval()
    end = time.time()
    start_time = time.time()
    max_iters = len(src_train_loader)
    meters = MetricLogger(delimiter="  ")
    logger.info(
        ">>>>>>>>>>>>>>>> Initialize semantic distributions >>>>>>>>>>>>>>>>")
    logger.info("Total iterations: {}".format(max_iters))
    with torch.no_grad():
        for i, (src_input, src_label, _) in enumerate(src_train_loader):
            data_time = time.time() - end

            src_input = src_input.cuda(non_blocking=True)
            src_label = src_label.cuda(non_blocking=True).long()

            src_feat = feature_extractor(src_input)
            src_out = classifier(src_feat)
            B, N, Hs, Ws = src_feat.size()
            _, C, _, _ = src_out.size()

            # source mask: downsample the ground-truth label
            src_mask = F.interpolate(src_label.unsqueeze(0).float(),
                                     size=(Hs, Ws),
                                     mode='nearest').squeeze(0).long()
            src_mask = src_mask.contiguous().view(B * Hs * Ws, )

            # feature level
            src_feat = src_feat.permute(0, 2, 3,
                                        1).contiguous().view(B * Hs * Ws, N)
            feat_estimator.update(features=src_feat.detach().clone(),
                                  labels=src_mask)

            # output level
            src_out = src_out.permute(0, 2, 3,
                                      1).contiguous().view(B * Hs * Ws, C)
            out_estimator.update(features=src_out.detach().clone(),
                                 labels=src_mask)

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            iteration = iteration + 1
            eta_seconds = meters.time.global_avg * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % 20 == 0 or iteration == max_iters:
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}", "iter: {iter}", "{meters}",
                        "max mem: {memory:.2f} GB"
                    ]).format(eta=eta_string,
                              iter=iteration,
                              meters=str(meters),
                              memory=torch.cuda.max_memory_allocated() /
                              1024.0 / 1024.0 / 1024.0))

            if iteration == max_iters:
                feat_estimator.save(name='feat_dist.pth')
                out_estimator.save(name='out_dist.pth')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    logger.info("Total time: {} ({:.4f} s / it)".format(
        total_time_str, total_time / max_iters))
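semantic_dist_estimator keeps per-class Mean and CoVariance buffers that Example #3's contrastive loss consumes. A simplified, hypothetical sketch of the count-weighted running-mean half of that update (the real estimator also tracks covariance and implements save(name=...)):

import torch

class RunningClassMean:  # hypothetical, mean-only estimator
    def __init__(self, feature_num, num_classes, ignore_label=255, device='cpu'):
        self.ignore_label = ignore_label
        self.Mean = torch.zeros(num_classes, feature_num, device=device)
        self.Amount = torch.zeros(num_classes, device=device)

    def update(self, features, labels):
        # features: (N, feature_num) flattened pixels; labels: (N,)
        keep = labels != self.ignore_label
        features, labels = features[keep], labels[keep]
        for c in labels.unique():
            feats_c = features[labels == c]
            n_old, n_new = self.Amount[c], feats_c.size(0)
            self.Mean[c] = (self.Mean[c] * n_old + feats_c.sum(0)) / (n_old + n_new)
            self.Amount[c] += n_new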
Example #6: test evaluates on cityscapes_train and saves color pseudo-labels, masking pixels whose confidence falls below the per-class thresholds from Example #2.
def test(cfg, thres_const):
    logger = logging.getLogger("pseudo_label.tester")
    logger.info("Start testing")
    device = torch.device(cfg.MODEL.DEVICE)

    feature_extractor = build_feature_extractor(cfg)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_extractor_weights = strip_prefix_if_present(checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_extractor_weights)
        classifier_weights = strip_prefix_if_present(checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    feature_extractor.eval()
    classifier.eval()

    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    torch.cuda.empty_cache()
    assert cfg.DATASETS.TEST == 'cityscapes_train'
    dataset_name = cfg.DATASETS.TEST
    output_folder = '.'
    if cfg.OUTPUT_DIR:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "soft_labels", dataset_name)
        mkdir(output_folder)

    test_data = build_dataset(cfg, mode='test', is_source=False)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=cfg.SOLVER.BATCH_SIZE_VAL,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True,
                                              sampler=None)

    for index, batch in enumerate(test_loader):
        if index % 100 == 0:
            logger.info("{} processed".format(index))

        x, y, name = batch
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True).long()

        pred = inference(feature_extractor, classifier, x, y, flip=False)

        output = pred.max(1)[1]
        intersection, union, target = intersectionAndUnionGPU(output, y, cfg.MODEL.NUM_CLASSES, cfg.INPUT.IGNORE_LABEL)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

        # save the pseudo label
        pred = pred.cpu().numpy().squeeze()
        pred_max = np.max(pred, 0)
        pred_label = pred.argmax(0)
        for i in range(cfg.MODEL.NUM_CLASSES):
            pred_label[(pred_max < thres_const[i]) & (pred_label == i)] = 255
        mask = get_color_pallete(pred_label, "city")
        mask_filename = name[0] if len(name[0].split("/")) < 2 else name[0].split("/")[1]
        mask.save(os.path.join(output_folder, mask_filename))

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
    for i in range(cfg.MODEL.NUM_CLASSES):
        logger.info(
            '{} {} iou/accuracy: {:.4f}/{:.4f}.'.format(i, test_data.trainid2name[i], iou_class[i], accuracy_class[i]))
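Every tester funnels predictions through intersectionAndUnionGPU. A sketch following the widely used semseg implementation this code appears to assume: per-class pixel counts via histograms, with ignored pixels removed (the .clone() is a small safety addition so the caller's tensor is not mutated):

import torch

def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    # output, target: integer label maps of identical shape.
    assert output.shape == target.shape
    output = output.reshape(-1).clone()
    target = target.reshape(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection.float(), bins=K, min=0, max=K - 1)
    area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
    area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target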
Example #7: test_all evaluates every .pth checkpoint in a directory with the two-classifier BCDM model, tracks the best mIoU, and writes a CSV of per-checkpoint results.
def test_all(cfg, saveres):
    logger = logging.getLogger("BCDM.tester")
    logger.info("Start testing")
    device = torch.device(cfg.MODEL.DEVICE)

    feature_extractor = build_feature_extractor(cfg)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    classifier_2 = build_classifier(cfg)
    classifier_2.to(device)

    test_data = build_dataset(cfg, mode='test', is_source=False)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=cfg.TEST.BATCH_SIZE,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True,
                                              sampler=None)

    test_stats = []
    best_iter = 0
    best_miou = 0

    for fname in sorted(os.listdir(cfg.resume)):
        if not fname.endswith('.pth'):
            continue
        ckpt_path = os.path.join(cfg.resume, fname)
        logger.info("Loading checkpoint from {}".format(ckpt_path))
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))

        feature_extractor_weights = strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_extractor_weights)
        classifier_weights = strip_prefix_if_present(checkpoint['classifier'],
                                                     'module.')
        classifier.load_state_dict(classifier_weights)
        classifier_weights_2 = strip_prefix_if_present(
            checkpoint['classifier_2'], 'module.')
        classifier_2.load_state_dict(classifier_weights_2)

        feature_extractor.eval()
        classifier.eval()
        classifier_2.eval()

        intersection_meter = AverageMeter()
        union_meter = AverageMeter()
        target_meter = AverageMeter()

        torch.cuda.empty_cache()
        dataset_name = cfg.DATASETS.TEST
        output_folder = '.'
        if cfg.OUTPUT_DIR:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                         dataset_name)
            mkdir(output_folder)
            if saveres:
                output_folder = os.path.join(cfg.OUTPUT_DIR, "inference",
                                             dataset_name,
                                             fname.replace('.pth', ''))
                mkdir(output_folder)

        for batch in tqdm(test_loader):
            x, y, name = batch
            x = x.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True).long()
            pred = inference(feature_extractor,
                             classifier,
                             classifier_2,
                             x,
                             y,
                             flip=False)
            output = pred.max(1)[1]
            intersection, union, target = intersectionAndUnionGPU(
                output, y, cfg.MODEL.NUM_CLASSES, cfg.INPUT.IGNORE_LABEL)
            intersection = intersection.cpu().numpy()
            union = union.cpu().numpy()
            target = target.cpu().numpy()
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)

            # running pixel accuracy over batches seen so far (not logged here)
            accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

            if saveres:
                pred = pred.cpu().numpy().squeeze().argmax(0)
                mask = get_color_pallete(pred, "city")
                mask_filename = name[0].split("/")[1]
                mask.save(os.path.join(output_folder, mask_filename))

        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
        mIoU = np.mean(iou_class)
        mAcc = np.mean(accuracy_class)
        allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

        iter_num = int(re.findall(r'\d+', fname)[0])
        rec = {'iters': iter_num, 'mIoU': mIoU}
        logger.info(
            'Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(
                mIoU, mAcc, allAcc))
        for i in range(cfg.MODEL.NUM_CLASSES):
            rec[test_data.trainid2name[i]] = iou_class[i]
            logger.info('{} {} iou/accuracy: {:.4f}/{:.4f}.'.format(
                i, test_data.trainid2name[i], iou_class[i], accuracy_class[i]))
        test_stats.append(rec)

        if mIoU > best_miou:
            best_iter = iter_num
            best_miou = mIoU

    logger.info('Best result is got at iters {} with mIoU {:.4f}.'.format(
        best_iter, best_miou))
    with open(os.path.join(output_folder, 'test_results.csv'), 'w') as handle:
        for i, rec in enumerate(test_stats):
            if i == 0:
                handle.write(','.join(list(rec.keys())) + '\n')
            line = [str(rec[key]) for key in rec.keys()]
            handle.write(','.join(line) + '\n')
Example #8: the FADA adversarial training loop alternates segmentation updates with a domain discriminator trained on clamped soft predictions.
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("FADA.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg, adv=True)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    model_D = build_adversarial_discriminator(cfg)
    model_D.to(device)

    if local_rank == 0:
        print(feature_extractor)
        print(model_D)

    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(
            cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        pg3 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        model_D = torch.nn.parallel.DistributedDataParallel(
            model_D,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg3)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    optimizer_D = torch.optim.Adam(model_D.parameters(),
                                   lr=cfg.SOLVER.BASE_LR_D,
                                   betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    start_epoch = 0
    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)
        if "model_D" in checkpoint:
            logger.info("Loading model_D from {}".format(cfg.resume))
            model_D_weights = checkpoint[
                'model_D'] if distributed else strip_prefix_if_present(
                    checkpoint['model_D'], 'module.')
            model_D.load_state_dict(model_D_weights)
        # if "optimizer_fea" in checkpoint:
        #     logger.info("Loading optimizer_fea from {}".format(cfg.resume))
        #     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
        # if "optimizer_cls" in checkpoint:
        #     logger.info("Loading optimizer_cls from {}".format(cfg.resume))
        #     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
        # if "optimizer_D" in checkpoint:
        #     logger.info("Loading optimizer_D from {}".format(cfg.resume))
        #     optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        # if "iteration" in checkpoint:
        #     iteration = checkpoint['iteration']

    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(
            tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(
        src_train_data,
        batch_size=batch_size,
        shuffle=(src_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=src_train_sampler,
        drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(
        tgt_train_data,
        batch_size=batch_size,
        shuffle=(tgt_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=tgt_train_sampler,
        drop_last=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    bce_loss = torch.nn.BCELoss(reduction='none')

    max_iters = cfg.SOLVER.MAX_ITER
    source_label = 0
    target_label = 1
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    model_D.train()
    start_training_time = time.time()
    end = time.time()
    for i, ((src_input, src_label, src_name), (tgt_input, _, _)) in enumerate(
            zip(src_train_loader, tgt_train_loader)):
        # torch.distributed.barrier()
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        current_lr_D = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                            cfg.SOLVER.BASE_LR_D,
                                            iteration,
                                            max_iters,
                                            power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10
        for index in range(len(optimizer_D.param_groups)):
            optimizer_D.param_groups[index]['lr'] = current_lr_D

        # torch.distributed.barrier()
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        optimizer_D.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        tgt_size = tgt_input.shape[-2:]
        inp = torch.cat(
            [src_input, F.interpolate(tgt_input.detach(), src_size)])
        src_fea = feature_extractor(inp)[:batch_size]

        src_pred = classifier(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        loss_seg = criterion(src_pred, src_label)
        loss_seg.backward()

        # torch.distributed.barrier()

        # generate soft labels
        src_soft_label = F.softmax(src_pred, dim=1).detach()
        src_soft_label[src_soft_label > 0.9] = 0.9

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred = tgt_pred.div(temperature)
        tgt_soft_label = F.softmax(tgt_pred, dim=1)

        tgt_soft_label = tgt_soft_label.detach()
        tgt_soft_label[tgt_soft_label > 0.9] = 0.9

        tgt_D_pred = model_D(tgt_fea, tgt_size)
        loss_adv_tgt = 0.001 * soft_label_cross_entropy(
            tgt_D_pred,
            torch.cat(
                (tgt_soft_label, torch.zeros_like(tgt_soft_label)), dim=1))
        loss_adv_tgt.backward()

        optimizer_fea.step()
        optimizer_cls.step()

        optimizer_D.zero_grad()
        # torch.distributed.barrier()

        src_D_pred = model_D(src_fea.detach(), src_size)
        loss_D_src = 0.5 * soft_label_cross_entropy(
            src_D_pred,
            torch.cat(
                (src_soft_label, torch.zeros_like(src_soft_label)), dim=1))
        loss_D_src.backward()

        tgt_D_pred = model_D(tgt_fea.detach(), tgt_size)
        loss_D_tgt = 0.5 * soft_label_cross_entropy(
            tgt_D_pred,
            torch.cat(
                (torch.zeros_like(tgt_soft_label), tgt_soft_label), dim=1))
        loss_D_tgt.backward()

        # torch.distributed.barrier()

        optimizer_D.step()

        meters.update(loss_seg=loss_seg.item())
        meters.update(loss_adv_tgt=loss_adv_tgt.item())
        meters.update(loss_D=(loss_D_src.item() + loss_D_tgt.item()))
        meters.update(loss_D_src=loss_D_src.item())
        meters.update(loss_D_tgt=loss_D_tgt.item())

        iteration = iteration + 1

        n = src_input.size(0)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f} MB",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if (iteration == cfg.SOLVER.MAX_ITER or iteration %
                cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'model_D': model_D.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict(),
                    'optimizer_D': optimizer_D.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.MAX_ITER))

    return feature_extractor, classifier
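The adversarial terms above pair the discriminator's 2C-channel output with a soft target built by concatenating the clamped softmax with zeros for the other domain. A sketch of the per-pixel soft-label cross-entropy this implies (assumed form, not copied from the repo):

import torch
import torch.nn.functional as F

def soft_label_cross_entropy(pred, soft_label, pixel_weights=None):
    # pred: (N, 2C, H, W) discriminator logits; soft_label: per-pixel target
    # distribution of the same shape. Loss = -sum_c q_c * log p_c per pixel.
    loss = torch.sum(-soft_label.float() * F.log_softmax(pred, dim=1), dim=1)
    if pixel_weights is not None:
        loss = pixel_weights * loss
    return torch.mean(loss)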
Example #9: single-checkpoint evaluation with optional saving of colorized predictions.
def test(cfg, saveres):
    logger = logging.getLogger("FADA.tester")
    logger.info("Start testing")
    device = torch.device(cfg.MODEL.DEVICE)

    feature_extractor = build_feature_extractor(cfg)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)
    
    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_extractor_weights = strip_prefix_if_present(checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_extractor_weights)
        classifier_weights = strip_prefix_if_present(checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    feature_extractor.eval()
    classifier.eval()
    
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    
    torch.cuda.empty_cache()  # TODO check if it helps
    dataset_name = cfg.DATASETS.TEST
    output_folder = '.'
    if cfg.OUTPUT_DIR:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        mkdir(output_folder)

    test_data = build_dataset(cfg, mode='test', is_source=False)

    test_loader = torch.utils.data.DataLoader(test_data, batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, sampler=None)

    
    for batch in tqdm(test_loader):
        x, y, name = batch
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True).long()

        pred = inference(feature_extractor, classifier, x, y, flip=False)
        # pred = multi_scale_inference(feature_extractor, classifier, x, y, flip=True)

        output = pred.max(1)[1]
        intersection, union, target = intersectionAndUnionGPU(output, y, cfg.MODEL.NUM_CLASSES, cfg.INPUT.IGNORE_LABEL)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)

        if saveres:
            pred = pred.cpu().numpy().squeeze()
            pred_max = np.max(pred, 0)
            pred = pred.argmax(0)
            # uncomment the following line when visualizing SYNTHIA->Cityscapes
            # pred = transform_color(pred)
            mask = get_color_pallete(pred, "city")
            mask_filename = name[0] if len(name[0].split("/")) < 2 else name[0].split("/")[1]
            mask.save(os.path.join(output_folder, mask_filename))
    
    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
    for i in range(cfg.MODEL.NUM_CLASSES):
        logger.info('{} {} iou/accuracy: {:.4f}/{:.4f}.'.format(i, test_data.trainid2name[i], iou_class[i], accuracy_class[i]))
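The single-classifier inference used here resizes logits to the label resolution and returns class probabilities, optionally averaging a horizontally flipped pass. A minimal sketch of that behaviour (Example #7's two-classifier variant would additionally combine both heads):

import torch
import torch.nn.functional as F

def inference(feature_extractor, classifier, image, label, flip=False):
    size = label.shape[-2:]
    with torch.no_grad():
        output = classifier(feature_extractor(image))
        output = F.interpolate(output, size=size, mode='bilinear', align_corners=True)
        output = F.softmax(output, dim=1)
        if flip:
            # average with the prediction on the horizontally flipped image
            flipped = classifier(feature_extractor(torch.flip(image, [3])))
            flipped = F.interpolate(flipped, size=size, mode='bilinear', align_corners=True)
            output = (output + torch.flip(F.softmax(flipped, dim=1), [3])) / 2
    return output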
Example #10: the BCDM training loop with the three-step bi-classifier scheme: (A) source CE plus target entropy, (B) maximize classifier discrepancy, (C) train the feature extractor to minimize it.
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("BCDM.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    classifier_2 = build_classifier(cfg)
    classifier_2.to(device)

    if local_rank == 0:
        print(feature_extractor)

    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(
            cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        pg3 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier_2 = torch.nn.parallel.DistributedDataParallel(
            classifier_2,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg3)

        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(list(classifier.parameters()) +
                                    list(classifier_2.parameters()),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)
        classifier_2_weights = checkpoint[
            'classifier_2'] if distributed else strip_prefix_if_present(
                checkpoint['classifier_2'], 'module.')
        classifier_2.load_state_dict(classifier_2_weights)
        # if "optimizer_fea" in checkpoint:
        #     logger.info("Loading optimizer_fea from {}".format(cfg.resume))
        #     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
        # if "optimizer_cls" in checkpoint:
        #     logger.info("Loading optimizer_cls from {}".format(cfg.resume))
        #     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
        # if "iteration" in checkpoint:
        #     iteration = checkpoint['iteration']

    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(
            tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(
        src_train_data,
        batch_size=batch_size,
        shuffle=(src_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=src_train_sampler,
        drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(
        tgt_train_data,
        batch_size=batch_size,
        shuffle=(tgt_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=tgt_train_sampler,
        drop_last=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    max_iters = cfg.SOLVER.MAX_ITER
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    classifier_2.train()
    start_training_time = time.time()
    end = time.time()
    for i, ((src_input, src_label, src_name), (tgt_input, _, _)) in enumerate(
            zip(src_train_loader, tgt_train_loader)):
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        # Step A: train on source (CE loss) and target (Ent loss)
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        tgt_size = tgt_input.shape[-2:]

        src_fea = feature_extractor(src_input)
        src_pred = classifier(src_fea, src_size)
        src_pred_2 = classifier_2(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        src_pred_2 = src_pred_2.div(temperature)
        # source segmentation loss
        loss_seg = criterion(src_pred, src_label) + criterion(
            src_pred_2, src_label)

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
        tgt_pred = F.softmax(tgt_pred, dim=1)
        tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
        loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
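        # entropy_loss is assumed to be the mean Shannon entropy of the
        # per-pixel softmax, i.e. -(p * log p).sum(dim=1).mean(); minimizing
        # it pushes target predictions toward confident (low-entropy) outputs.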
        total_loss = loss_seg + cfg.SOLVER.ENT_LOSS * loss_ent
        total_loss.backward()
        # torch.distributed.barrier()
        optimizer_fea.step()
        optimizer_cls.step()

        # Step B: train bi-classifier to maximize loss_cdd
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_fea = feature_extractor(src_input)
        src_pred = classifier(src_fea, src_size)
        src_pred_2 = classifier_2(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        src_pred_2 = src_pred_2.div(temperature)
        loss_seg = criterion(src_pred, src_label) + criterion(
            src_pred_2, src_label)

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
        tgt_pred = F.softmax(tgt_pred, dim=1)
        tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
        loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
        loss_cdd = discrepancy_calc(tgt_pred, tgt_pred_2)
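        # discrepancy_calc is assumed to measure the disagreement between the
        # two classifiers, e.g. the mean L1 distance |p1 - p2| as in Maximum
        # Classifier Discrepancy; Step B maximizes it (note the minus sign
        # below) while Step C minimizes it through the feature extractor.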
        total_loss = loss_seg - cfg.SOLVER.CDD_LOSS * loss_cdd + cfg.SOLVER.ENT_LOSS * loss_ent
        total_loss.backward()
        optimizer_cls.step()

        # Step C: train feature extractor to min loss_cdd
        for k in range(cfg.SOLVER.NUM_K):
            optimizer_fea.zero_grad()
            optimizer_cls.zero_grad()
            tgt_fea = feature_extractor(tgt_input)
            tgt_pred = classifier(tgt_fea, tgt_size)
            tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
            tgt_pred = F.softmax(tgt_pred, dim=1)
            tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
            loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
            loss_cdd = discrepancy_calc(tgt_pred, tgt_pred_2)
            total_loss = cfg.SOLVER.CDD_LOSS * loss_cdd + cfg.SOLVER.ENT_LOSS * loss_ent
            total_loss.backward()
            optimizer_fea.step()  # update the feature extractor to minimize loss_cdd

        meters.update(loss_seg=loss_seg.item())
        meters.update(loss_cdd=loss_cdd.item())
        meters.update(loss_ent=loss_ent.item())

        iteration = iteration + 1

        n = src_input.size(0)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if (iteration == cfg.SOLVER.MAX_ITER or iteration %
                cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'classifier_2': classifier_2.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.MAX_ITER))

    return feature_extractor, classifier, classifier_2
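
The helpers entropy_loss, discrepancy_calc, and adjust_learning_rate are called above but defined elsewhere in the repository. Below is a minimal sketch of plausible implementations, assuming the standard entropy-minimization term, an MCD-style mean-L1 discrepancy, and a poly learning-rate schedule; the repository's exact definitions may differ.

import torch


def entropy_loss(prob):
    # prob: (B, C, H, W) softmax probabilities; mean per-pixel Shannon entropy
    return -(prob * torch.log(prob + 1e-10)).sum(dim=1).mean()


def discrepancy_calc(prob_1, prob_2):
    # mean L1 distance between the two classifiers' softmax maps
    return torch.mean(torch.abs(prob_1 - prob_2))


def adjust_learning_rate(method, base_lr, iteration, max_iters, power=0.9):
    # 'poly' decay; the method argument mirrors cfg.SOLVER.LR_METHOD
    if method == 'poly':
        return base_lr * (1.0 - float(iteration) / max_iters)**power
    return base_lr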
Exemplo n.º 11
0
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Semantic Segmentation Testing")
    parser.add_argument(
        "-cfg",
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    torch.backends.cudnn.benchmark = True

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("TSNE", save_dir, 0)
    logger.info(cfg)

    logger.info("Loaded configuration file {}".format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    logger = logging.getLogger("TSNE.tester")
    logger.info("Start")
    device = torch.device(cfg.MODEL.DEVICE)

    feature_extractor = build_feature_extractor(cfg)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_extractor_weights = strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_extractor_weights)
        classifier_weights = strip_prefix_if_present(checkpoint['classifier'],
                                                     'module.')
        classifier.load_state_dict(classifier_weights)

    feature_extractor.eval()
    classifier.eval()

    torch.cuda.empty_cache()
    dataset_name = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
        mkdir(output_folder)

    test_data = build_dataset(cfg, mode='test', is_source=False)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=cfg.TSNE.BATCH_SIZE,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True,
                                              sampler=None)
    for batch in tqdm(test_loader):
        x, y, name = batch
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True).long()

        pred, feat, outs = inference(feature_extractor, classifier, x, y)
        filename = name[0] if len(
            name[0].split("/")) < 2 else name[0].split("/")[1]

        # draw t-sne
        B, A, Ht, Wt = outs.size()
        tSNE_features = outs.permute(0, 2, 3,
                                     1).contiguous().view(B * Ht * Wt, A)
        tSNE_labels = F.interpolate(y.unsqueeze(0).float(),
                                    size=(Ht, Wt),
                                    mode='nearest').squeeze(0).long()
        tSNE_labels = tSNE_labels.contiguous().view(B * Ht * Wt)

        # remove pixels marked with the ignore label
        mask = tSNE_labels != cfg.INPUT.IGNORE_LABEL
        tSNE_labels = tSNE_labels[mask]
        tSNE_features = tSNE_features[mask]

        draw(tSNE_features=tSNE_features,
             tSNE_labels=tSNE_labels,
             name=filename,
             cfg=cfg)
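
The draw helper used above is defined elsewhere in the repository. Below is a minimal sketch of what it plausibly does, assuming scikit-learn's TSNE and matplotlib; the subsampling cap, the 2-D embedding settings, and the output path are illustrative assumptions, not the repository's exact implementation.

import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE


def draw(tSNE_features, tSNE_labels, name, cfg, max_points=5000):
    feats = tSNE_features.detach().cpu().numpy()
    labels = tSNE_labels.detach().cpu().numpy()
    # subsample pixels so t-SNE stays tractable
    if feats.shape[0] > max_points:
        idx = np.random.choice(feats.shape[0], max_points, replace=False)
        feats, labels = feats[idx], labels[idx]
    emb = TSNE(n_components=2, init='pca').fit_transform(feats)
    plt.figure(figsize=(8, 8))
    for c in np.unique(labels):
        m = labels == c
        plt.scatter(emb[m, 0], emb[m, 1], s=1, label=str(c))
    plt.legend(markerscale=10, fontsize=6)
    plt.savefig(os.path.join(cfg.OUTPUT_DIR, '{}_tsne.png'.format(name)))
    plt.close()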