Example 1
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("SDCA.trainer")

    # create network
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor = build_feature_extractor(cfg)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if local_rank == 0:
        print(classifier)

    # batch size: half for source and half for target
    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        batch_size = int(cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg1
        )
        pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg2
        )
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    # init optimizer
    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(), lr=cfg.SOLVER.BASE_LR * 10, momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    # load checkpoint
    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_weights = checkpoint['feature_extractor'] if distributed else strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_weights)
        classifier_weights = checkpoint['classifier'] if distributed else strip_prefix_if_present(
            checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    # init data loader
    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)
    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None
    src_train_loader = torch.utils.data.DataLoader(src_train_data, batch_size=batch_size,
                                                   shuffle=(src_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=src_train_sampler, drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(tgt_train_data, batch_size=batch_size,
                                                   shuffle=(tgt_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=tgt_train_sampler, drop_last=True)

    # init loss
    ce_criterion = nn.CrossEntropyLoss(ignore_index=255)
    pcl_criterion = PixelContrastiveLoss(cfg)

    # load semantic distributions
    logger.info(">>>>>>>>>>>>>>>> Load semantic distributions >>>>>>>>>>>>>>>>")
    _, backbone_name = cfg.MODEL.NAME.split('_')
    feature_num = 2048 if backbone_name.startswith('resnet') else 1024
    feat_estimator = semantic_dist_estimator(feature_num=feature_num, cfg=cfg)
    if cfg.SOLVER.MULTI_LEVEL:
        out_estimator = semantic_dist_estimator(feature_num=cfg.MODEL.NUM_CLASSES, cfg=cfg)

    iteration = 0
    start_training_time = time.time()
    end = time.time()
    save_to_disk = local_rank == 0
    max_iters = cfg.SOLVER.MAX_ITER
    meters = MetricLogger(delimiter="  ")

    logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
    feature_extractor.train()
    classifier.train()

    for i, ((src_input, src_label, src_name), (tgt_input, _, _)) in enumerate(zip(src_train_loader, tgt_train_loader)):

        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        src_feat = feature_extractor(src_input)
        src_out = classifier(src_feat)
        tgt_feat = feature_extractor(tgt_input)
        tgt_out = classifier(tgt_feat)
        tgt_out_softmax = F.softmax(tgt_out, dim=1)

        # supervision loss
        src_pred = F.interpolate(src_out, size=src_size, mode='bilinear', align_corners=True)
        if cfg.SOLVER.LAMBDA_LOV > 0:
            src_pred_softmax = F.softmax(src_pred, dim=1)
            loss_lov = lovasz_softmax(src_pred_softmax, src_label, ignore=255)
            loss_sup = ce_criterion(src_pred, src_label) + cfg.SOLVER.LAMBDA_LOV * loss_lov
            meters.update(loss_lov=loss_lov.item())
        else:
            loss_sup = ce_criterion(src_pred, src_label)
        meters.update(loss_sup=loss_sup.item())

        # source mask: downsample the ground-truth label
        B, A, Hs, Ws = src_feat.size()
        src_mask = F.interpolate(src_label.unsqueeze(0).float(), size=(Hs, Ws), mode='nearest').squeeze(0).long()
        src_mask = src_mask.contiguous().view(B * Hs * Ws, )
        assert not src_mask.requires_grad
        # target mask: constant threshold
        _, _, Ht, Wt = tgt_feat.size()
        tgt_out_maxvalue, tgt_mask = torch.max(tgt_out_softmax, dim=1)
        for cls_idx in range(cfg.MODEL.NUM_CLASSES):
            tgt_mask[(tgt_out_maxvalue < cfg.SOLVER.DELTA) & (tgt_mask == cls_idx)] = 255
        tgt_mask = tgt_mask.contiguous().view(B * Ht * Wt, )
        assert not tgt_mask.requires_grad

        src_feat = src_feat.permute(0, 2, 3, 1).contiguous().view(B * Hs * Ws, A)
        tgt_feat = tgt_feat.permute(0, 2, 3, 1).contiguous().view(B * Ht * Wt, A)
        # update feature-level statistics
        feat_estimator.update(features=src_feat.detach(), labels=src_mask)

        # contrastive loss on both domains
        loss_feat = pcl_criterion(Mean=feat_estimator.Mean.detach(),
                                  CoVariance=feat_estimator.CoVariance.detach(),
                                  feat=src_feat,
                                  labels=src_mask) \
                    + pcl_criterion(Mean=feat_estimator.Mean.detach(),
                                  CoVariance=feat_estimator.CoVariance.detach(),
                                  feat=tgt_feat,
                                  labels=tgt_mask)
        meters.update(loss_feat=loss_feat.item())

        if cfg.SOLVER.MULTI_LEVEL:
            src_out = src_out.permute(0, 2, 3, 1).contiguous().view(B * Hs * Ws, cfg.MODEL.NUM_CLASSES)
            tgt_out = tgt_out.permute(0, 2, 3, 1).contiguous().view(B * Ht * Wt, cfg.MODEL.NUM_CLASSES)

            # update output-level statistics
            out_estimator.update(features=src_out.detach(), labels=src_mask)

            # the proposed contrastive loss on prediction map
            loss_out = pcl_criterion(Mean=out_estimator.Mean.detach(),
                                     CoVariance=out_estimator.CoVariance.detach(),
                                     feat=src_out,
                                     labels=src_mask) \
                       + pcl_criterion(Mean=out_estimator.Mean.detach(),
                                       CoVariance=out_estimator.CoVariance.detach(),
                                       feat=tgt_out,
                                       labels=tgt_mask)
            meters.update(loss_out=loss_out.item())

            loss = loss_sup \
                   + cfg.SOLVER.LAMBDA_FEAT * loss_feat \
                   + cfg.SOLVER.LAMBDA_OUT * loss_out
        else:
            loss = loss_sup + cfg.SOLVER.LAMBDA_FEAT * loss_feat

        loss.backward()

        optimizer_fea.step()
        optimizer_cls.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        iteration += 1
        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.02f} GB"
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
                )
            )

        if (iteration == cfg.SOLVER.MAX_ITER or iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(cfg.OUTPUT_DIR, "model_iter{:06d}.pth".format(iteration))
            torch.save({'iteration': iteration,
                        'feature_extractor': feature_extractor.state_dict(),
                        'classifier': classifier.state_dict(),
                        'optimizer_fea': optimizer_fea.state_dict(),
                        'optimizer_cls': optimizer_cls.state_dict(),
                        }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / cfg.SOLVER.STOP_ITER
        )
    )

    return feature_extractor, classifier
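
All four trainers call an adjust_learning_rate helper that is not shown in these snippets. A minimal sketch of a polynomial ("poly") decay schedule, consistent with how the helper is invoked above and assuming cfg.SOLVER.LR_METHOD is 'poly' and that the helper simply returns the new learning rate:

def adjust_learning_rate(method, base_lr, iteration, max_iters, power=0.9):
    # Poly decay: lr = base_lr * (1 - iteration / max_iters) ** power.
    # Only the 'poly' branch is sketched; the real helper may support more methods.
    if method == 'poly':
        return base_lr * ((1.0 - float(iteration) / max_iters) ** power)
    # Fallback: keep the base learning rate unchanged.
    return base_lr
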
Example 2
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("FADA.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg, adv=True)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    model_D = build_adversarial_discriminator(cfg)
    model_D.to(device)

    if local_rank == 0:
        print(feature_extractor)
        print(model_D)

    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(
            cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        pg3 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        model_D = torch.nn.parallel.DistributedDataParallel(
            model_D,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg3)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    optimizer_D = torch.optim.Adam(model_D.parameters(),
                                   lr=cfg.SOLVER.BASE_LR_D,
                                   betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    start_epoch = 0
    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)
        if "model_D" in checkpoint:
            logger.info("Loading model_D from {}".format(cfg.resume))
            model_D_weights = checkpoint[
                'model_D'] if distributed else strip_prefix_if_present(
                    checkpoint['model_D'], 'module.')
            model_D.load_state_dict(model_D_weights)
        # if "optimizer_fea" in checkpoint:
        #     logger.info("Loading optimizer_fea from {}".format(cfg.resume))
        #     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
        # if "optimizer_cls" in checkpoint:
        #     logger.info("Loading optimizer_cls from {}".format(cfg.resume))
        #     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
        # if "optimizer_D" in checkpoint:
        #     logger.info("Loading optimizer_D from {}".format(cfg.resume))
        #     optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        # if "iteration" in checkpoint:
        #     iteration = checkpoint['iteration']

    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(
            tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(
        src_train_data,
        batch_size=batch_size,
        shuffle=(src_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=src_train_sampler,
        drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(
        tgt_train_data,
        batch_size=batch_size,
        shuffle=(tgt_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=tgt_train_sampler,
        drop_last=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    bce_loss = torch.nn.BCELoss(reduction='none')

    max_iters = cfg.SOLVER.MAX_ITER
    source_label = 0
    target_label = 1
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    model_D.train()
    start_training_time = time.time()
    end = time.time()
    for i, ((src_input, src_label, src_name),
            (tgt_input, _,
             _)) in enumerate(zip(src_train_loader, tgt_train_loader)):
        #             torch.distributed.barrier()
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        current_lr_D = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                            cfg.SOLVER.BASE_LR_D,
                                            iteration,
                                            max_iters,
                                            power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10
        for index in range(len(optimizer_D.param_groups)):
            optimizer_D.param_groups[index]['lr'] = current_lr_D


        # torch.distributed.barrier()

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        optimizer_D.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        tgt_size = tgt_input.shape[-2:]
        inp = torch.cat(
            [src_input, F.interpolate(tgt_input.detach(), src_size)])
        src_fea = feature_extractor(inp)[:batch_size]

        src_pred = classifier(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        loss_seg = criterion(src_pred, src_label)
        loss_seg.backward()

        # torch.distributed.barrier()

        # generate soft labels
        src_soft_label = F.softmax(src_pred, dim=1).detach()
        src_soft_label[src_soft_label > 0.9] = 0.9

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred = tgt_pred.div(temperature)
        tgt_soft_label = F.softmax(tgt_pred, dim=1)

        tgt_soft_label = tgt_soft_label.detach()
        tgt_soft_label[tgt_soft_label > 0.9] = 0.9

        tgt_D_pred = model_D(tgt_fea, tgt_size)
        loss_adv_tgt = 0.001 * soft_label_cross_entropy(
            tgt_D_pred,
            torch.cat(
                (tgt_soft_label, torch.zeros_like(tgt_soft_label)), dim=1))
        loss_adv_tgt.backward()

        optimizer_fea.step()
        optimizer_cls.step()

        optimizer_D.zero_grad()
        # torch.distributed.barrier()

        src_D_pred = model_D(src_fea.detach(), src_size)
        loss_D_src = 0.5 * soft_label_cross_entropy(
            src_D_pred,
            torch.cat(
                (src_soft_label, torch.zeros_like(src_soft_label)), dim=1))
        loss_D_src.backward()

        tgt_D_pred = model_D(tgt_fea.detach(), tgt_size)
        loss_D_tgt = 0.5 * soft_label_cross_entropy(
            tgt_D_pred,
            torch.cat(
                (torch.zeros_like(tgt_soft_label), tgt_soft_label), dim=1))
        loss_D_tgt.backward()

        # torch.distributed.barrier()

        optimizer_D.step()

        meters.update(loss_seg=loss_seg.item())
        meters.update(loss_adv_tgt=loss_adv_tgt.item())
        meters.update(loss_D=(loss_D_src.item() + loss_D_tgt.item()))
        meters.update(loss_D_src=loss_D_src.item())
        meters.update(loss_D_tgt=loss_D_tgt.item())

        iteration = iteration + 1

        n = src_input.size(0)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if (iteration == cfg.SOLVER.MAX_ITER or iteration %
                cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'model_D': model_D.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict(),
                    'optimizer_D': optimizer_D.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (cfg.SOLVER.MAX_ITER)))

    return feature_extractor, classifier
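
The adversarial terms above go through soft_label_cross_entropy, which the snippet does not define. A minimal sketch, assuming the two-argument form used above: a pixel-wise cross-entropy of the discriminator logits against soft (probability) targets, averaged over all pixels.

import torch.nn.functional as F

def soft_label_cross_entropy(pred, soft_label):
    # -sum_c soft_label_c * log_softmax(pred)_c per pixel, then mean.
    # Sketch only; the original helper may add per-pixel weighting
    # or a different reduction.
    log_prob = F.log_softmax(pred, dim=1)
    return -(soft_label * log_prob).sum(dim=1).mean()
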
Example 3
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("SelfSupervised.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if local_rank == 0:
        print(feature_extractor)
        print(classifier)

    batch_size = cfg.SOLVER.BATCH_SIZE
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(cfg.SOLVER.BATCH_SIZE /
                         torch.distributed.get_world_size())
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    src_train_data = build_dataset(cfg, mode='train', is_source=True)

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(src_train_data,
                                               batch_size=batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=4,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    ce_criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    max_iters = cfg.SOLVER.MAX_ITER
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    start_training_time = time.time()
    end = time.time()

    for i, (src_input, src_label, _) in enumerate(train_loader):
        data_time = time.time() - end
        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        size = src_label.shape[-2:]
        pred = classifier(feature_extractor(src_input), size)
        loss = ce_criterion(pred, src_label)
        loss.backward()

        optimizer_fea.step()
        optimizer_cls.step()
        meters.update(loss_seg=loss.item())
        iteration += 1

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.2f} GB",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 /
                    1024.0 / 1024.0,
                ))

        if (iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0
                or iteration == max_iters) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.STOP_ITER))

    return feature_extractor, classifier
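
Every resume branch passes the checkpoint through strip_prefix_if_present when the model is not wrapped in DistributedDataParallel. A minimal sketch of that helper, assuming it only strips a leading key prefix such as 'module.':

from collections import OrderedDict

def strip_prefix_if_present(state_dict, prefix):
    # Drop a leading prefix (e.g. 'module.' added by DistributedDataParallel)
    # from every key so the weights load into an unwrapped model.
    # Sketch only; the helper in the original repos may differ.
    if not all(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict
    stripped = OrderedDict()
    for key, value in state_dict.items():
        stripped[key[len(prefix):]] = value
    return stripped
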
Example 4
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("BCDM.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    classifier_2 = build_classifier(cfg)
    classifier_2.to(device)

    if local_rank == 0:
        print(feature_extractor)

    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(
            cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        pg3 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier_2 = torch.nn.parallel.DistributedDataParallel(
            classifier_2,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg3)

        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(list(classifier.parameters()) +
                                    list(classifier_2.parameters()),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)
        classifier_2_weights = checkpoint[
            'classifier_2'] if distributed else strip_prefix_if_present(
                checkpoint['classifier_2'], 'module.')
        classifier_2.load_state_dict(classifier_2_weights)
        # if "optimizer_fea" in checkpoint:
        #     logger.info("Loading optimizer_fea from {}".format(cfg.resume))
        #     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
        # if "optimizer_cls" in checkpoint:
        #     logger.info("Loading optimizer_cls from {}".format(cfg.resume))
        #     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
        # if "iteration" in checkpoint:
        #     iteration = checkpoint['iteration']

    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(
            tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(
        src_train_data,
        batch_size=batch_size,
        shuffle=(src_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=src_train_sampler,
        drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(
        tgt_train_data,
        batch_size=batch_size,
        shuffle=(tgt_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=tgt_train_sampler,
        drop_last=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    max_iters = cfg.SOLVER.MAX_ITER
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    classifier_2.train()
    start_training_time = time.time()
    end = time.time()
    for i, ((src_input, src_label, src_name),
            (tgt_input, _,
             _)) in enumerate(zip(src_train_loader, tgt_train_loader)):
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        # Step A: train on source (CE loss) and target (Ent loss)
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        tgt_size = tgt_input.shape[-2:]

        src_fea = feature_extractor(src_input)
        src_pred = classifier(src_fea, src_size)
        src_pred_2 = classifier_2(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        src_pred_2 = src_pred_2.div(temperature)
        # source segmentation loss
        loss_seg = criterion(src_pred, src_label) + criterion(
            src_pred_2, src_label)

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
        tgt_pred = F.softmax(tgt_pred, dim=1)
        tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
        loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
        total_loss = loss_seg + cfg.SOLVER.ENT_LOSS * loss_ent
        total_loss.backward()
        # torch.distributed.barrier()
        optimizer_fea.step()
        optimizer_cls.step()

        # Step B: train bi-classifier to maximize loss_cdd
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_fea = feature_extractor(src_input)
        src_pred = classifier(src_fea, src_size)
        src_pred_2 = classifier_2(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        src_pred_2 = src_pred_2.div(temperature)
        loss_seg = criterion(src_pred, src_label) + criterion(
            src_pred_2, src_label)

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
        tgt_pred = F.softmax(tgt_pred, dim=1)
        tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
        loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
        loss_cdd = discrepancy_calc(tgt_pred, tgt_pred_2)
        total_loss = loss_seg - cfg.SOLVER.CDD_LOSS * loss_cdd + cfg.SOLVER.ENT_LOSS * loss_ent
        total_loss.backward()
        optimizer_cls.step()

        # Step C: train feature extractor to min loss_cdd
        for k in range(cfg.SOLVER.NUM_K):
            optimizer_fea.zero_grad()
            optimizer_cls.zero_grad()
            tgt_fea = feature_extractor(tgt_input)
            tgt_pred = classifier(tgt_fea, tgt_size)
            tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
            tgt_pred = F.softmax(tgt_pred, dim=1)
            tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
            loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
            loss_cdd = discrepancy_calc(tgt_pred, tgt_pred_2)
            total_loss = cfg.SOLVER.CDD_LOSS * loss_cdd + cfg.SOLVER.ENT_LOSS * loss_ent
            total_loss.backward()
            optimizer_fea.step()

        meters.update(loss_seg=loss_seg.item())
        meters.update(loss_cdd=loss_cdd.item())
        meters.update(loss_ent=loss_ent.item())

        iteration = iteration + 1

        n = src_input.size(0)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if (iteration == cfg.SOLVER.MAX_ITER or iteration %
                cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'classifier_2': classifier_2.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.MAX_ITER))

    return feature_extractor, classifier, classifier_2
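
Steps A-C above rely on entropy_loss and discrepancy_calc, which are not defined in the snippet. Hedged sketches in the spirit of MCD/BCDM-style bi-classifier training, assuming both receive softmax-normalized (N, C, H, W) probability maps as in the calls above:

import torch

def entropy_loss(prob):
    # Mean Shannon entropy of per-pixel class probabilities; minimizing it
    # sharpens the target predictions. Sketch only.
    return -torch.mean(torch.sum(prob * torch.log(prob + 1e-6), dim=1))

def discrepancy_calc(prob_1, prob_2):
    # Mean L1 discrepancy between the two classifiers' probability maps.
    # Sketch only; the original may use a different distance.
    return torch.mean(torch.abs(prob_1 - prob_2))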