Example #1
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    print(model)
    print("")
    print("##############################################################")
    print("")
    #summary(model, (3, 2048, 1024))

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    #print(f"{bcolors.WARNING}Warning: No active frommets remain. Continue?{bcolors.ENDC}")
    
    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )
    

    return model
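Each example above assumes it is called from a training script that has already parsed the config and set up distributed training. Below is a minimal launcher sketch under that assumption; the argument names and the commented-out cfg/train calls are illustrative, not taken from the examples.

# Hypothetical launcher sketch: how a train(cfg, local_rank, distributed)
# function like the ones above is typically invoked. Only plain PyTorch /
# stdlib calls are used; the project-specific pieces stay commented out.
import argparse
import os

import torch
import torch.distributed as dist


def main():
    parser = argparse.ArgumentParser(description="hypothetical training launcher")
    parser.add_argument("--config-file", default="", metavar="FILE")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # torch.distributed.launch / torchrun export WORLD_SIZE for each process
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")

    # The config object and train() come from the surrounding project
    # (e.g. fcos_core) and are assumed here, not defined in this sketch:
    # cfg.merge_from_file(args.config_file)
    # cfg.freeze()
    # train(cfg, args.local_rank, distributed)


if __name__ == "__main__":
    main()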
Example #2
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    #summary(model,input_size=(2,3,1333,800))

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(
        cfg.MODEL.WEIGHT, cfg.MODEL.CLS_WEIGHT, cfg.MODEL.REG_WEIGHT,
        init_div=True, init_opti=False, init_model=True)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Example #3
def train(cfg, local_rank, distributed): # typical call: cfg, local_rank=0, distributed=False
    model = build_detection_model(cfg) # instantiate the model
    device = torch.device(cfg.MODEL.DEVICE) # cfg.MODEL.DEVICE = "cuda": place tensors on the GPU
    model.to(device) # run the model on the GPU

    if cfg.MODEL.USE_SYNCBN: # False by default
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model) # build the training optimizer
    scheduler = make_lr_scheduler(cfg, optimizer)  # set up the learning-rate scheduler

    if distributed: # False in the single-GPU case
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR # "."

    save_to_disk = get_rank() == 0 # True on the main process
    # the checkpoint holds the network's pretrained weights
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data) # merge the values from extra_checkpoint_data into arguments

    #make_data_loader
    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed, #False
        start_iter=arguments["iteration"], # 0
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD  #SOLVER.CHECKPOINT_PERIOD = 2500

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Example #4
def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
    logger = logging.getLogger("fcos_core.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
        if not pytorch_1_1_0_or_later:
            scheduler.step()

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())
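        # debugging aid kept in the example: drop into pdb when the loss explodes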
        if losses > 1e5:
            import pdb
            pdb.set_trace()

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if pytorch_1_1_0_or_later:
            scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
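The version check around scheduler.step() in the loop above follows the PyTorch 1.1.0 change that moved scheduler.step() after optimizer.step(). A minimal, self-contained sketch of that ordering, using only plain PyTorch names (nothing from fcos_core):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

for step in range(20):
    loss = model(torch.randn(8, 4)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # PyTorch >= 1.1.0 ordering: after optimizer.step()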
Example #5
    def __init__(self, cfg, local_rank, distributed):
        self.writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
        self.start_epoch = 0
        # self.epochs = cfg.MAX_ITER / len()
        self.epochs = 5
        model = build_detection_model(cfg)
        device = torch.device(cfg.MODEL.DEVICE)
        model.to(device)

        if cfg.MODEL.USE_SYNCBN:
            assert is_pytorch_1_1_0_or_later(), \
                "SyncBatchNorm is only available in pytorch >= 1.1.0"
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)

        if distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
                # this should be removed if we update BatchNorm stats
                broadcast_buffers=False,
            )

        arguments = {}
        arguments["iteration"] = 0

        output_dir = cfg.OUTPUT_DIR

        save_to_disk = get_rank() == 0
        checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                             output_dir, save_to_disk)
        extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
        arguments.update(extra_checkpoint_data)

        # The key change is in the dataset; the data loaders are torch.utils.data.DataLoader objects
        # import pdb; pdb.set_trace()
        # train_loader = build_single_data_loader(cfg)
        self.train_loader = make_train_loader(
            cfg, start_iter=arguments["iteration"])
        # self.val_loader = make_val_loader(cfg)
        # train_data_loader = make_data_loader(
        #     cfg,
        #     is_train=True,
        #     is_distributed=distributed,
        #     start_iter=arguments["iteration"],
        # )

        checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.checkpointer = checkpointer
        self.device = device
        self.checkpoint_period = checkpoint_period
        self.arguments = arguments
        self.distributed = distributed
Example #6
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    # import matplotlib.pyplot as plt
    # import numpy as np
    #
    # def imshow(img):
    #     #img = img / 2 + 0.5  # unnormalize
    #     img = img + 115
    #     img = img[[2, 1, 0]]
    #     npimg = img.numpy().astype(np.int)
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # import torchvision
    # dataiter = iter(data_loader)
    # images, target, _ = dataiter.next()  # target and pixel values are in the hundreds
    #
    # imshow(torchvision.utils.make_grid(images.tensors))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Example #7
def train(cfg, local_rank, distributed, labelenc_fpath):
    model = LabelEncStep2Network(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    # Load LabelEncodingFunction
    # Initialize FPN and Head from Step1 weights
    if not checkpointer.has_checkpoint():
        labelenc_weights = torch.load(labelenc_fpath, map_location=torch.device('cpu'))
        # load LabelEncodingFunction
        model.module.label_encoding_function.load_state_dict(
                labelenc_weights['label_encoding_function'], strict=True)
        # Initialize Head
        model.module.rpn.load_state_dict(
                labelenc_weights['rpn'], strict=True)
        if model.module.roi_heads:
            model.module.roi_heads.load_state_dict(
                labelenc_weights['roi_heads'], strict=True)
        # Initialize FPN
        fpn_weight = model.module.label_encoding_function.fpn.state_dict()
        model.module.backbone.fpn.load_state_dict(fpn_weight, strict=True)


    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Example #8
def do_train(
    model,
    arch,
    data_loader,
    val_loader,
    optimizer,
    alpha_optim,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
    cfg,
    tb_info={},
    first_order=True,
):
    logger = logging.getLogger("fad_core.trainer")
    logger.info("Start the architecture search")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()

    Genotype = model.genotype()
    iteration = 0
    for n_m, genotype in enumerate(Genotype):
        logger.info("genotype = {}".format(genotype))

    for iteration, ((images, targets, _),
                    (images_val, targets_val,
                     _)) in enumerate(zip(data_loader, val_loader),
                                      start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        if len(targets) == cfg.SOLVER.IMS_PER_BATCH and len(
                targets_val) == cfg.SOLVER.IMS_PER_BATCH:

            images = images.to(device)
            targets = [target.to(device) for target in targets]
            images_val = images_val.to(device)
            targets_val = [target.to(device) for target in targets_val]

            # -------------- update alpha
            lr = scheduler.get_lr()[0]
            alpha_optim.zero_grad()

            if not first_order:
                # ----- 2nd order
                arch.unrolled_backward(images, targets, images_val,
                                       targets_val, lr, optimizer)
            else:
                # ----- 1st order
                arch.first_order_backward(images_val, targets_val)
            alpha_optim.step()

            # --------------- update w
            loss_dict = model(images, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            for lkey, lval in loss_dict_reduced.items():
                loss_dict_reduced[lkey] = lval.mean()
            meters.update(loss=losses_reduced.mean(), **loss_dict_reduced)

            # --------- tensorboard logger
            tb_logger = tb_info.get('tb_logger', None)
            if tb_logger:
                tb_prefix = '{}loss'.format(tb_info['prefix'])
                tb_logger.add_scalar(tb_prefix, losses_reduced.mean(),
                                     iteration)

                for key, value in loss_dict_reduced.items():
                    tb_prefix = "{}{}".format(tb_info['prefix'], key)
                    tb_logger.add_scalar(tb_prefix, value, iteration)

                tb_prefix = '{}loss'.format(tb_info['prefix'])
                tb_logger.add_scalar(tb_prefix + '_z_lr', lr, iteration)

            optimizer.zero_grad()
            losses.mean().backward()
            torch.nn.utils.clip_grad_norm_(model.weights(), 20)
            optimizer.step()

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            eta_seconds = meters.time.global_avg * (max_iter - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % 20 == 0 or iteration == max_iter:
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                        memory=torch.cuda.max_memory_allocated() / 1024.0 /
                        1024.0,
                    ))

            if iteration % (checkpoint_period) == 0:
                checkpointer.save("model_{:07d}".format(iteration),
                                  **arguments)
            if iteration == max_iter:
                checkpointer.save("model_final", **arguments)

            # ---------- save genotype
            if cfg.MODEL.FAD.PLOT and (iteration % checkpoint_period == 0):

                Genotype = model.genotype()
                fw = open(f"{cfg.OUTPUT_DIR}/genotype.log", "w")
                for n_m, genotype in enumerate(Genotype):
                    logger.info("genotype = {}".format(genotype))
                    # write genotype for augment
                    fw.write(f"{genotype}\n")

                    # genotype as a image
                    plot_path = os.path.join(cfg.OUTPUT_DIR + '/plots',
                                             "Module%d" % n_m,
                                             "Iter{:06d}".format(iteration))
                    caption = "Iteration {}".format(iteration)
                    plot(genotype.normal, plot_path + "-normal", caption)
                model.print_alphas(logger)
                fw.close()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
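Unlike Example #4, this search loop clips gradients with torch.nn.utils.clip_grad_norm_ before the optimizer step. A minimal stand-alone sketch of that clipping step, using a plain PyTorch module rather than the model.weights() helper from the example:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).pow(2).sum()
optimizer.zero_grad()
loss.backward()
# cap the global L2 norm of all gradients before the optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20.0)
optimizer.step()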
Example #9
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)  # build the model with build_detection_model
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:  # SyncBN: SyncBatchNorm, synchronized across GPUs
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
            model)  # convert the model's BatchNorm layers to SyncBatchNorm

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:  # whether to use distributed training
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}  # create the arguments dict
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg,
        model,
        optimizer,
        scheduler,
        output_dir,
        save_to_disk  # checkpoint
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(  # build the training data loader with make_data_loader
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Example #10
def train(cfg, local_rank, distributed):
    writer = SummaryWriter('runs/{}'.format(cfg.OUTPUT_DIR))
    ##########################################################################
    ############################# Initial Model ##############################
    ##########################################################################
    model = {}
    device = torch.device(cfg.MODEL.DEVICE)

    backbone = build_backbone(cfg).to(device)
    fcos = build_rpn(cfg, backbone.out_channels).to(device)

    if cfg.MODEL.ADV.USE_DIS_GLOBAL:
        if cfg.MODEL.ADV.USE_DIS_P7:
            dis_P7 = FCOSDiscriminator(
                num_convs=cfg.MODEL.ADV.DIS_P7_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.GRL_WEIGHT_P7,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P6:
            dis_P6 = FCOSDiscriminator(
                num_convs=cfg.MODEL.ADV.DIS_P6_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.GRL_WEIGHT_P6,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P5:
            dis_P5 = FCOSDiscriminator(
                num_convs=cfg.MODEL.ADV.DIS_P5_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.GRL_WEIGHT_P5,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P4:
            dis_P4 = FCOSDiscriminator(
                num_convs=cfg.MODEL.ADV.DIS_P4_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.GRL_WEIGHT_P4,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P3:
            dis_P3 = FCOSDiscriminator(
                num_convs=cfg.MODEL.ADV.DIS_P3_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.GRL_WEIGHT_P3,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)

    if cfg.MODEL.ADV.USE_DIS_CENTER_AWARE:
        if cfg.MODEL.ADV.USE_DIS_P7:
            dis_P7_CA = FCOSDiscriminator_CA(
                num_convs=cfg.MODEL.ADV.CA_DIS_P7_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.CA_GRL_WEIGHT_P7,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P6:
            dis_P6_CA = FCOSDiscriminator_CA(
                num_convs=cfg.MODEL.ADV.CA_DIS_P6_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.CA_GRL_WEIGHT_P6,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P5:
            dis_P5_CA = FCOSDiscriminator_CA(
                num_convs=cfg.MODEL.ADV.CA_DIS_P5_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.CA_GRL_WEIGHT_P5,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P4:
            dis_P4_CA = FCOSDiscriminator_CA(
                num_convs=cfg.MODEL.ADV.CA_DIS_P4_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.CA_GRL_WEIGHT_P4,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P3:
            dis_P3_CA = FCOSDiscriminator_CA(
                num_convs=cfg.MODEL.ADV.CA_DIS_P3_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.CA_GRL_WEIGHT_P3,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)

    if cfg.MODEL.ADV.USE_DIS_CONDITIONAL:
        if cfg.MODEL.ADV.USE_DIS_P7:
            dis_P7_Cond = FCOSDiscriminator_CondA(
                num_convs=cfg.MODEL.ADV.COND_DIS_P7_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.COND_GRL_WEIGHT_P7,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                # center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN,
                class_align=cfg.MODEL.ADV.COND_CLASS,
                reg_left_align=cfg.MODEL.ADV.COND_REG.LEFT,
                reg_top_align=cfg.MODEL.ADV.COND_REG.TOP,
                expand_dim=cfg.MODEL.ADV.COND_EXPAND).to(device)
        if cfg.MODEL.ADV.USE_DIS_P6:
            dis_P6_Cond = FCOSDiscriminator_CondA(
                num_convs=cfg.MODEL.ADV.COND_DIS_P6_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.COND_GRL_WEIGHT_P6,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                # center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN,
                class_align=cfg.MODEL.ADV.COND_CLASS,
                reg_left_align=cfg.MODEL.ADV.COND_REG.LEFT,
                reg_top_align=cfg.MODEL.ADV.COND_REG.TOP,
                expand_dim=cfg.MODEL.ADV.COND_EXPAND).to(device)
        if cfg.MODEL.ADV.USE_DIS_P5:
            dis_P5_Cond = FCOSDiscriminator_CondA(
                num_convs=cfg.MODEL.ADV.COND_DIS_P5_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.COND_GRL_WEIGHT_P5,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                # center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN,
                class_align=cfg.MODEL.ADV.COND_CLASS,
                reg_left_align=cfg.MODEL.ADV.COND_REG.LEFT,
                reg_top_align=cfg.MODEL.ADV.COND_REG.TOP,
                expand_dim=cfg.MODEL.ADV.COND_EXPAND).to(device)
        if cfg.MODEL.ADV.USE_DIS_P4:
            dis_P4_Cond = FCOSDiscriminator_CondA(
                num_convs=cfg.MODEL.ADV.COND_DIS_P4_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.COND_GRL_WEIGHT_P4,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                # center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN,
                class_align=cfg.MODEL.ADV.COND_CLASS,
                reg_left_align=cfg.MODEL.ADV.COND_REG.LEFT,
                reg_top_align=cfg.MODEL.ADV.COND_REG.TOP,
                expand_dim=cfg.MODEL.ADV.COND_EXPAND).to(device)
        if cfg.MODEL.ADV.USE_DIS_P3:
            dis_P3_Cond = FCOSDiscriminator_CondA(
                num_convs=cfg.MODEL.ADV.COND_DIS_P3_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.COND_GRL_WEIGHT_P3,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                # center_aware_type=cfg.MODEL.ADV.CENTER_AWARE_TYPE,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN,
                class_align=cfg.MODEL.ADV.COND_CLASS,
                reg_left_align=cfg.MODEL.ADV.COND_REG.LEFT,
                reg_top_align=cfg.MODEL.ADV.COND_REG.TOP,
                expand_dim=cfg.MODEL.ADV.COND_EXPAND).to(device)

    if cfg.MODEL.ADV.USE_DIS_HEAD:
        if cfg.MODEL.ADV.USE_DIS_P7:
            dis_P7_HA = FCOSDiscriminator_HA(
                num_convs=cfg.MODEL.ADV.HA_DIS_P7_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.HA_GRL_WEIGHT_P7,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P6:
            dis_P6_HA = FCOSDiscriminator_HA(
                num_convs=cfg.MODEL.ADV.HA_DIS_P6_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.HA_GRL_WEIGHT_P6,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P5:
            dis_P5_HA = FCOSDiscriminator_HA(
                num_convs=cfg.MODEL.ADV.HA_DIS_P5_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.HA_GRL_WEIGHT_P5,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P4:
            dis_P4_HA = FCOSDiscriminator_HA(
                num_convs=cfg.MODEL.ADV.HA_DIS_P4_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.HA_GRL_WEIGHT_P4,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)
        if cfg.MODEL.ADV.USE_DIS_P3:
            dis_P3_HA = FCOSDiscriminator_HA(
                num_convs=cfg.MODEL.ADV.HA_DIS_P3_NUM_CONVS,
                grad_reverse_lambda=cfg.MODEL.ADV.HA_GRL_WEIGHT_P3,
                center_aware_weight=cfg.MODEL.ADV.CENTER_AWARE_WEIGHT,
                grl_applied_domain=cfg.MODEL.ADV.GRL_APPLIED_DOMAIN).to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
        fcos = torch.nn.SyncBatchNorm.convert_sync_batchnorm(fcos)

        if cfg.MODEL.ADV.USE_DIS_GLOBAL:
            if cfg.MODEL.ADV.USE_DIS_P7:
                dis_P7 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(dis_P7)
            if cfg.MODEL.ADV.USE_DIS_P6:
                dis_P6 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(dis_P6)
            if cfg.MODEL.ADV.USE_DIS_P5:
                dis_P5 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(dis_P5)
            if cfg.MODEL.ADV.USE_DIS_P4:
                dis_P4 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(dis_P4)
            if cfg.MODEL.ADV.USE_DIS_P3:
                dis_P3 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(dis_P3)

        if cfg.MODEL.ADV.USE_DIS_CENTER_AWARE:
            if cfg.MODEL.ADV.USE_DIS_P7:
                dis_P7_CA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P7_CA)
            if cfg.MODEL.ADV.USE_DIS_P6:
                dis_P6_CA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P6_CA)
            if cfg.MODEL.ADV.USE_DIS_P5:
                dis_P5_CA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P5_CA)
            if cfg.MODEL.ADV.USE_DIS_P4:
                dis_P4_CA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P4_CA)
            if cfg.MODEL.ADV.USE_DIS_P3:
                dis_P3_CA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P3_CA)

        if cfg.MODEL.ADV.USE_DIS_CONDITIONAL:
            if cfg.MODEL.ADV.USE_DIS_P7:
                dis_P7_Cond = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P7_Cond)
            if cfg.MODEL.ADV.USE_DIS_P6:
                dis_P6_Cond = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P6_Cond)
            if cfg.MODEL.ADV.USE_DIS_P5:
                dis_P5_Cond = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P5_Cond)
            if cfg.MODEL.ADV.USE_DIS_P4:
                dis_P4_Cond = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P4_Cond)
            if cfg.MODEL.ADV.USE_DIS_P3:
                dis_P3_Cond = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P3_Cond)

        if cfg.MODEL.ADV.USE_DIS_HEAD:
            if cfg.MODEL.ADV.USE_DIS_P7:
                dis_P7_HA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P7_HA)
            if cfg.MODEL.ADV.USE_DIS_P6:
                dis_P6_HA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P6_HA)
            if cfg.MODEL.ADV.USE_DIS_P5:
                dis_P5_HA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P5_HA)
            if cfg.MODEL.ADV.USE_DIS_P4:
                dis_P4_HA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P4_HA)
            if cfg.MODEL.ADV.USE_DIS_P3:
                dis_P3_HA = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    dis_P3_HA)

    ##########################################################################
    #################### Initial Optimizer and Scheduler #####################
    ##########################################################################
    optimizer = {}
    optimizer["backbone"] = make_optimizer(cfg, backbone, name='backbone')
    optimizer["fcos"] = make_optimizer(cfg, fcos, name='fcos')

    if cfg.MODEL.ADV.USE_DIS_GLOBAL:
        if cfg.MODEL.ADV.USE_DIS_P7:
            optimizer["dis_P7"] = make_optimizer(cfg,
                                                 dis_P7,
                                                 name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P6:
            optimizer["dis_P6"] = make_optimizer(cfg,
                                                 dis_P6,
                                                 name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P5:
            optimizer["dis_P5"] = make_optimizer(cfg,
                                                 dis_P5,
                                                 name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P4:
            optimizer["dis_P4"] = make_optimizer(cfg,
                                                 dis_P4,
                                                 name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P3:
            optimizer["dis_P3"] = make_optimizer(cfg,
                                                 dis_P3,
                                                 name='discriminator')

    if cfg.MODEL.ADV.USE_DIS_CENTER_AWARE:
        if cfg.MODEL.ADV.USE_DIS_P7:
            optimizer["dis_P7_CA"] = make_optimizer(cfg,
                                                    dis_P7_CA,
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P6:
            optimizer["dis_P6_CA"] = make_optimizer(cfg,
                                                    dis_P6_CA,
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P5:
            optimizer["dis_P5_CA"] = make_optimizer(cfg,
                                                    dis_P5_CA,
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P4:
            optimizer["dis_P4_CA"] = make_optimizer(cfg,
                                                    dis_P4_CA,
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P3:
            optimizer["dis_P3_CA"] = make_optimizer(cfg,
                                                    dis_P3_CA,
                                                    name='discriminator')

    if cfg.MODEL.ADV.USE_DIS_CONDITIONAL:
        if cfg.MODEL.ADV.USE_DIS_P7:
            optimizer["dis_P7_Cond"] = make_optimizer(cfg,
                                                      dis_P7_Cond,
                                                      name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P6:
            optimizer["dis_P6_Cond"] = make_optimizer(cfg,
                                                      dis_P6_Cond,
                                                      name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P5:
            optimizer["dis_P5_Cond"] = make_optimizer(cfg,
                                                      dis_P5_Cond,
                                                      name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P4:
            optimizer["dis_P4_Cond"] = make_optimizer(cfg,
                                                      dis_P4_Cond,
                                                      name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P3:
            optimizer["dis_P3_Cond"] = make_optimizer(cfg,
                                                      dis_P3_Cond,
                                                      name='discriminator')

    if cfg.MODEL.ADV.USE_DIS_HEAD:
        if cfg.MODEL.ADV.USE_DIS_P7:
            optimizer["dis_P7_HA"] = make_optimizer(cfg,
                                                    dis_P7_HA,
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P6:
            optimizer["dis_P6_HA"] = make_optimizer(cfg,
                                                    dis_P6_HA,
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P5:
            optimizer["dis_P5_HA"] = make_optimizer(cfg,
                                                    dis_P5_HA,
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P4:
            optimizer["dis_P4_HA"] = make_optimizer(cfg,
                                                    dis_P4_HA,
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P3:
            optimizer["dis_P3_HA"] = make_optimizer(cfg,
                                                    dis_P3_HA,
                                                    name='discriminator')

    scheduler = {}
    scheduler["backbone"] = make_lr_scheduler(cfg,
                                              optimizer["backbone"],
                                              name='backbone')
    scheduler["fcos"] = make_lr_scheduler(cfg, optimizer["fcos"], name='fcos')

    if cfg.MODEL.ADV.USE_DIS_GLOBAL:
        if cfg.MODEL.ADV.USE_DIS_P7:
            scheduler["dis_P7"] = make_lr_scheduler(cfg,
                                                    optimizer["dis_P7"],
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P6:
            scheduler["dis_P6"] = make_lr_scheduler(cfg,
                                                    optimizer["dis_P6"],
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P5:
            scheduler["dis_P5"] = make_lr_scheduler(cfg,
                                                    optimizer["dis_P5"],
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P4:
            scheduler["dis_P4"] = make_lr_scheduler(cfg,
                                                    optimizer["dis_P4"],
                                                    name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P3:
            scheduler["dis_P3"] = make_lr_scheduler(cfg,
                                                    optimizer["dis_P3"],
                                                    name='discriminator')

    if cfg.MODEL.ADV.USE_DIS_CENTER_AWARE:
        if cfg.MODEL.ADV.USE_DIS_P7:
            scheduler["dis_P7_CA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P7_CA"],
                                                       name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P6:
            scheduler["dis_P6_CA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P6_CA"],
                                                       name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P5:
            scheduler["dis_P5_CA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P5_CA"],
                                                       name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P4:
            scheduler["dis_P4_CA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P4_CA"],
                                                       name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P3:
            scheduler["dis_P3_CA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P3_CA"],
                                                       name='discriminator')

    if cfg.MODEL.ADV.USE_DIS_CONDITIONAL:
        if cfg.MODEL.ADV.USE_DIS_P7:
            scheduler["dis_P7_Cond"] = make_lr_scheduler(
                cfg, optimizer["dis_P7_Cond"], name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P6:
            scheduler["dis_P6_Cond"] = make_lr_scheduler(
                cfg, optimizer["dis_P6_Cond"], name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P5:
            scheduler["dis_P5_Cond"] = make_lr_scheduler(
                cfg, optimizer["dis_P5_Cond"], name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P4:
            scheduler["dis_P4_Cond"] = make_lr_scheduler(
                cfg, optimizer["dis_P4_Cond"], name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P3:
            scheduler["dis_P3_Cond"] = make_lr_scheduler(
                cfg, optimizer["dis_P3_Cond"], name='discriminator')

    if cfg.MODEL.ADV.USE_DIS_HEAD:
        if cfg.MODEL.ADV.USE_DIS_P7:
            scheduler["dis_P7_HA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P7_HA"],
                                                       name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P6:
            scheduler["dis_P6_HA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P6_HA"],
                                                       name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P5:
            scheduler["dis_P5_HA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P5_HA"],
                                                       name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P4:
            scheduler["dis_P4_HA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P4_HA"],
                                                       name='discriminator')
        if cfg.MODEL.ADV.USE_DIS_P3:
            scheduler["dis_P3_HA"] = make_lr_scheduler(cfg,
                                                       optimizer["dis_P3_HA"],
                                                       name='discriminator')

    ##########################################################################
    ######################## DistributedDataParallel #########################
    ##########################################################################
    if distributed:
        backbone = torch.nn.parallel.DistributedDataParallel(
            backbone,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False)
        fcos = torch.nn.parallel.DistributedDataParallel(
            fcos,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False)

        if cfg.MODEL.ADV.USE_DIS_GLOBAL:
            if cfg.MODEL.ADV.USE_DIS_P7:
                dis_P7 = torch.nn.parallel.DistributedDataParallel(
                    dis_P7,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P6:
                dis_P6 = torch.nn.parallel.DistributedDataParallel(
                    dis_P6,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P5:
                dis_P5 = torch.nn.parallel.DistributedDataParallel(
                    dis_P5,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P4:
                dis_P4 = torch.nn.parallel.DistributedDataParallel(
                    dis_P4,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P3:
                dis_P3 = torch.nn.parallel.DistributedDataParallel(
                    dis_P3,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)

        if cfg.MODEL.ADV.USE_DIS_CENTER_AWARE:
            if cfg.MODEL.ADV.USE_DIS_P7:
                dis_P7_CA = torch.nn.parallel.DistributedDataParallel(
                    dis_P7_CA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P6:
                dis_P6_CA = torch.nn.parallel.DistributedDataParallel(
                    dis_P6_CA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P5:
                dis_P5_CA = torch.nn.parallel.DistributedDataParallel(
                    dis_P5_CA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P4:
                dis_P4_CA = torch.nn.parallel.DistributedDataParallel(
                    dis_P4_CA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P3:
                dis_P3_CA = torch.nn.parallel.DistributedDataParallel(
                    dis_P3_CA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)

        if cfg.MODEL.ADV.USE_DIS_CONDITIONAL:
            if cfg.MODEL.ADV.USE_DIS_P7:
                dis_P7_Cond = torch.nn.parallel.DistributedDataParallel(
                    dis_P7_Cond,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P6:
                dis_P6_Cond = torch.nn.parallel.DistributedDataParallel(
                    dis_P6_Cond,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P5:
                dis_P5_Cond = torch.nn.parallel.DistributedDataParallel(
                    dis_P5_Cond,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P4:
                dis_P4_Cond = torch.nn.parallel.DistributedDataParallel(
                    dis_P4_Cond,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P3:
                dis_P3_Cond = torch.nn.parallel.DistributedDataParallel(
                    dis_P3_Cond,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)

        if cfg.MODEL.ADV.USE_DIS_HEAD:
            if cfg.MODEL.ADV.USE_DIS_P7:
                dis_P7_HA = torch.nn.parallel.DistributedDataParallel(
                    dis_P7_HA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P6:
                dis_P6_HA = torch.nn.parallel.DistributedDataParallel(
                    dis_P6_HA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P5:
                dis_P5_HA = torch.nn.parallel.DistributedDataParallel(
                    dis_P5_HA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P4:
                dis_P4_HA = torch.nn.parallel.DistributedDataParallel(
                    dis_P4_HA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)
            if cfg.MODEL.ADV.USE_DIS_P3:
                dis_P3_HA = torch.nn.parallel.DistributedDataParallel(
                    dis_P3_HA,
                    device_ids=[local_rank],
                    output_device=local_rank,
                    # this should be removed if we update BatchNorm stats
                    broadcast_buffers=False)

    ##########################################################################
    ########################### Save Model to Dict ###########################
    ##########################################################################
    model["backbone"] = backbone
    model["fcos"] = fcos

    if cfg.MODEL.ADV.USE_DIS_GLOBAL:
        if cfg.MODEL.ADV.USE_DIS_P7:
            model["dis_P7"] = dis_P7
        if cfg.MODEL.ADV.USE_DIS_P6:
            model["dis_P6"] = dis_P6
        if cfg.MODEL.ADV.USE_DIS_P5:
            model["dis_P5"] = dis_P5
        if cfg.MODEL.ADV.USE_DIS_P4:
            model["dis_P4"] = dis_P4
        if cfg.MODEL.ADV.USE_DIS_P3:
            model["dis_P3"] = dis_P3

    if cfg.MODEL.ADV.USE_DIS_CENTER_AWARE:
        if cfg.MODEL.ADV.USE_DIS_P7:
            model["dis_P7_CA"] = dis_P7_CA
        if cfg.MODEL.ADV.USE_DIS_P6:
            model["dis_P6_CA"] = dis_P6_CA
        if cfg.MODEL.ADV.USE_DIS_P5:
            model["dis_P5_CA"] = dis_P5_CA
        if cfg.MODEL.ADV.USE_DIS_P4:
            model["dis_P4_CA"] = dis_P4_CA
        if cfg.MODEL.ADV.USE_DIS_P3:
            model["dis_P3_CA"] = dis_P3_CA

    if cfg.MODEL.ADV.USE_DIS_CONDITIONAL:
        if cfg.MODEL.ADV.USE_DIS_P7:
            model["dis_P7_Cond"] = dis_P7_Cond
        if cfg.MODEL.ADV.USE_DIS_P6:
            model["dis_P6_Cond"] = dis_P6_Cond
        if cfg.MODEL.ADV.USE_DIS_P5:
            model["dis_P5_Cond"] = dis_P5_Cond
        if cfg.MODEL.ADV.USE_DIS_P4:
            model["dis_P4_Cond"] = dis_P4_Cond
        if cfg.MODEL.ADV.USE_DIS_P3:
            model["dis_P3_Cond"] = dis_P3_Cond

    if cfg.MODEL.ADV.USE_DIS_HEAD:
        if cfg.MODEL.ADV.USE_DIS_P7:
            model["dis_P7_HA"] = dis_P7_HA
        if cfg.MODEL.ADV.USE_DIS_P6:
            model["dis_P6_HA"] = dis_P6_HA
        if cfg.MODEL.ADV.USE_DIS_P5:
            model["dis_P5_HA"] = dis_P5_HA
        if cfg.MODEL.ADV.USE_DIS_P4:
            model["dis_P4_HA"] = dis_P4_HA
        if cfg.MODEL.ADV.USE_DIS_P3:
            model["dis_P3_HA"] = dis_P3_HA

    ##########################################################################
    ################################ Training ################################
    ##########################################################################
    arguments = {}
    arguments["iteration"] = 0
    arguments["use_dis_global"] = cfg.MODEL.ADV.USE_DIS_GLOBAL
    arguments["use_dis_ca"] = cfg.MODEL.ADV.USE_DIS_CENTER_AWARE
    arguments["use_dis_conditional"] = cfg.MODEL.ADV.USE_DIS_CONDITIONAL
    arguments["use_dis_ha"] = cfg.MODEL.ADV.USE_DIS_HEAD
    arguments["ga_dis_lambda"] = cfg.MODEL.ADV.GA_DIS_LAMBDA
    arguments["ca_dis_lambda"] = cfg.MODEL.ADV.CA_DIS_LAMBDA
    arguments["cond_dis_lambda"] = cfg.MODEL.ADV.COND_DIS_LAMBDA
    arguments["ha_dis_lambda"] = cfg.MODEL.ADV.HA_DIS_LAMBDA

    arguments["use_feature_layers"] = []
    if cfg.MODEL.ADV.USE_DIS_P7:
        arguments["use_feature_layers"].append("P7")
    if cfg.MODEL.ADV.USE_DIS_P6:
        arguments["use_feature_layers"].append("P6")
    if cfg.MODEL.ADV.USE_DIS_P5:
        arguments["use_feature_layers"].append("P5")
    if cfg.MODEL.ADV.USE_DIS_P4:
        arguments["use_feature_layers"].append("P4")
    if cfg.MODEL.ADV.USE_DIS_P3:
        arguments["use_feature_layers"].append("P3")

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(f=cfg.MODEL.WEIGHT,
                                              load_dis=True,
                                              load_opt_sch=False)
    # arguments.update(extra_checkpoint_data)

    # Initial dataloader (both target and source domain)
    data_loader = {}
    data_loader["source"] = make_data_loader_source(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
    data_loader["target"] = make_data_loader_target(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )
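    # The two loaders are iterated in lockstep by do_train: the source loader yields labeled
    # images, the target loader provides images only.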
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, cfg, run_test, distributed, writer)

    return model
Beispiel #11
0
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    model = build_detection_model(cfg)
    # model, conversion_count = convert_to_shift_dbg(
    #         model,
    #         cfg.DEEPSHIFT_DEPTH,
    #         cfg.DEEPSHIFT_TYPE,
    #         convert_weights=True,
    #         use_kernel=cfg.DEEPSHIFT_USEKERNEL,
    #         rounding=cfg.DEEPSHIFT_ROUNDING,
    #         shift_range=cfg.DEEPSHIFT_RANGE)

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = get_rank() == 0
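    # iter_clear: do not restore optimizer/scheduler state and restart the iteration counter;
    # ignore_head: reuse the backbone/FPN weights but leave the detection head at its initialization.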
    if iter_clear:
        load_opt = False
        load_sch = False
    else:
        load_opt = True
        load_sch = True
    if ignore_head:
        load_body = True
        load_fpn = True
        load_head = False
    else:
        load_body = True
        load_fpn = True
        load_head = True
    # Load either a regular pretrained model or a DeepShift model as the starting checkpoint
    if cfg.MODEL.WEIGHT:
        checkpointer = DetectronCheckpointer(
            cfg, model, None, None, output_dir, save_to_disk
        )

        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT, load_opt=False, load_sch=False,
            load_body=load_body, load_fpn=load_fpn, load_head=load_head)
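        # Convert eligible Conv2d layers to DeepShift layers only after the float pretrained
        # weights have been loaded; the optimizer, scheduler, and checkpointer are created
        # afterwards so they reference the converted parameters.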
        
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)

        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )
    else:
        model, conversion_count = convert_to_shift(
            model,
            cfg.DEEPSHIFT_DEPTH,
            cfg.DEEPSHIFT_TYPE,
            convert_weights=True,
            use_kernel=cfg.DEEPSHIFT_USEKERNEL,
            rounding=cfg.DEEPSHIFT_ROUNDING,
            shift_range=cfg.DEEPSHIFT_RANGE)
        
        optimizer = make_optimizer(cfg, model)
        scheduler = make_lr_scheduler(cfg, optimizer)

        checkpointer = DetectronCheckpointer(
            cfg, model, optimizer, scheduler, output_dir, save_to_disk
        )

        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.WEIGHT, load_opt=False, load_sch=False,
            load_body=load_body, load_fpn=load_fpn, load_head=load_head)
    
    conv2d_layers_count = count_layer_type(model, torch.nn.Conv2d)
    linear_layers_count = count_layer_type(model, torch.nn.Linear)
    print("###### conversion_count: {}, not convert conv2d layer: {}, linear layer: {}".format(
        conversion_count, conv2d_layers_count, linear_layers_count))

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    arguments.update(extra_checkpoint_data)

    if iter_clear:
        arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    # Round the shift weights to exact powers of two and write out a deployable copy of the final model.
    model = round_shift_weights(model)
    torch.save({"model": model.state_dict()}, os.path.join(output_dir, "model_final_round.pth"))

    return model
Beispiel #12
0
def do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, cfg, run_test, distributed, writer,
             seperate_dis):
    USE_DIS_GLOBAL = arguments["use_dis_global"]
    USE_DIS_CENTER_AWARE = arguments["use_dis_ca"]
    USE_DIS_CONDITIONAL = arguments["use_dis_conditional"]
    USE_DIS_HEAD = arguments["use_dis_ha"]
    used_feature_layers = arguments["use_feature_layers"]
    # used_feature_layers = ['P7', 'P6', 'P5', 'P4', 'P3']  # debug override; superseded by the config-driven list above

    # dataloader
    data_loader_source = data_loader["source"]
    data_loader_target = data_loader["target"]

    # classified label of source domain and target domain
    source_label = 1.0
    target_label = 0.0
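    # Discriminators are trained to predict 1 for source-domain features and 0 for target-domain features.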

    # dis_lambda
    if USE_DIS_GLOBAL:
        ga_dis_lambda = arguments["ga_dis_lambda"]
    if USE_DIS_CENTER_AWARE:
        ca_dis_lambda = arguments["ca_dis_lambda"]
    if USE_DIS_CONDITIONAL:
        cond_dis_lambda = arguments["cond_dis_lambda"]
    if USE_DIS_HEAD:
        ha_dis_lambda = arguments["ha_dis_lambda"]

    # Start training
    logger = logging.getLogger("fcos_core.trainer")
    logger.info("Start training")

    # model.train()
    for k in model:
        model[k].train()

    meters = MetricLogger(delimiter="  ")
    assert len(data_loader_source) == len(data_loader_target)
    max_iter = max(len(data_loader_source), len(data_loader_target))
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()
    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()
    best_map50 = 0.0
    # results = run_test(cfg, model, distributed)
    # exit()
    for iteration, ((images_s, targets_s, _), (images_t, _, _)) \
        in enumerate(zip(data_loader_source, data_loader_target), start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration
        alpha = max(
            1 - iteration / cfg.MODEL.ADV.COND_WARMUP_ITER,
            cfg.MODEL.ADV.COND_ALPHA
        ) if cfg.MODEL.ADV.COND_SMOOTH else cfg.MODEL.ADV.COND_ALPHA
        cf_th = cfg.MODEL.ADV.COND_CONF
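        # alpha anneals linearly from ~1 down to COND_ALPHA over COND_WARMUP_ITER iterations when
        # COND_SMOOTH is on; cf_th is the confidence threshold passed to the conditional discriminators.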
        # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
        if not pytorch_1_1_0_or_later:
            # scheduler.step()
            for k in scheduler:
                scheduler[k].step()

        images_s = images_s.to(device)
        targets_s = [target_s.to(device) for target_s in targets_s]
        images_t = images_t.to(device)
        # targets_t = [target_t.to(device) for target_t in targets_t]

        # optimizer.zero_grad()
        for k in optimizer:
            optimizer[k].zero_grad()

        ##########################################################################
        #################### (1): train G with source domain #####################
        ##########################################################################

        loss_dict, features_s, score_maps_s = foward_detector(
            model, images_s, targets=targets_s, return_maps=True)
        labels = loss_dict['labels']
        reg_targets = loss_dict['reg_targets']
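        # labels / reg_targets are the per-FPN-level FCOS training targets returned by the detector;
        # the conditional discriminators below index them by pyramid level (int(layer[1]) - 3).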
        # rename loss to indicate domain
        loss_dict = {k + "_gs": loss_dict[k] for k in loss_dict if 'loss' in k}

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss_gs=losses_reduced, **loss_dict_reduced)

        writer.add_scalar('Loss_FCOS/gs', losses, iteration)
        writer.add_scalar('Loss_FCOS/cls_gs', loss_dict['loss_cls_gs'],
                          iteration)
        writer.add_scalar('Loss_FCOS/reg_gs', loss_dict['loss_reg_gs'],
                          iteration)
        writer.add_scalar('Loss_FCOS/centerness_gs',
                          loss_dict['loss_centerness_gs'], iteration)

        losses.backward(retain_graph=True)
        del loss_dict, losses

        ##########################################################################
        #################### (2): train D with source domain #####################
        ##########################################################################

        loss_dict = {}
        stat = {}
        for layer in used_feature_layers:
            # detach score_map
            for map_type in score_maps_s[layer]:
                score_maps_s[layer][map_type] = score_maps_s[layer][
                    map_type].detach()
            if seperate_dis:
                if USE_DIS_GLOBAL:
                    loss_dict["loss_adv_%s_ds" % layer] = \
                        ga_dis_lambda * model["dis_%s" % layer](features_s[layer], source_label, domain='source')
                if USE_DIS_CENTER_AWARE:
                    loss_dict["loss_adv_%s_CA_ds" % layer] = \
                        ca_dis_lambda * model["dis_%s_CA" % layer](features_s[layer], source_label, score_maps_s[layer], domain='source')
                if USE_DIS_CONDITIONAL:
                    loss_cond_l, cur_stat, idx = \
                        model["dis_%s_Cond" % layer](features_s[layer], source_label, score_maps_s[layer], domain='source', alpha=alpha, labels=labels[int(layer[1])-3], reg_targets=reg_targets[int(layer[1])-3], conf_th=cf_th)
                    stat["%s_source_left" %
                         layer] = [s / idx for s in cur_stat]
                    loss_cond_t, cur_stat, idx = \
                        model["dis_%s_Cond_t" % layer](features_s[layer], source_label, score_maps_s[layer], domain='source', alpha=alpha, labels=labels[int(layer[1])-3], reg_targets=reg_targets[int(layer[1])-3], conf_th=cf_th)
                    stat["%s_source_top" % layer] = [s / idx for s in cur_stat]
                    loss_dict["loss_adv_%s_Cond_ds" %
                              layer] = cond_dis_lambda * (loss_cond_l +
                                                          loss_cond_t)
                if USE_DIS_HEAD:
                    loss_dict["loss_adv_%s_HA_ds" % layer] = \
                        ha_dis_lambda * model["dis_%s_HA" % layer](source_label, score_maps_s[layer], domain='source')
            else:
                if USE_DIS_GLOBAL:
                    loss_dict["loss_adv_%s_ds" % layer] = \
                    ga_dis_lambda * model["dis_P7"](features_s[layer], source_label, domain='source')
                if USE_DIS_CENTER_AWARE:
                    loss_dict["loss_adv_%s_CA_ds" % layer] = \
                    ca_dis_lambda * model["dis_P7_CA"](features_s[layer], source_label, score_maps_s[layer], domain='source')
                if USE_DIS_CONDITIONAL:
                    loss_dict["loss_adv_%s_Cond_ds" % layer] = \
                    cond_dis_lambda * model["dis_P7_Cond"](features_s[layer], source_label, score_maps_s[layer], domain='source', alpha=alpha)
                if USE_DIS_HEAD:
                    loss_dict["loss_adv_%s_HA_ds" % layer] = \
                    ha_dis_lambda * model["dis_P7_HA"](source_label, score_maps_s[layer], domain='source')

        losses = sum(loss for loss in loss_dict.values())

        writer.add_scalar('Loss_DISC/ds', losses, iteration)
        if USE_DIS_GLOBAL:
            writer.add_scalar('Loss_DISC/P3_ds', loss_dict['loss_adv_P3_ds'],
                              iteration)
            writer.add_scalar('Loss_DISC/P4_ds', loss_dict['loss_adv_P4_ds'],
                              iteration)
            writer.add_scalar('Loss_DISC/P5_ds', loss_dict['loss_adv_P5_ds'],
                              iteration)
            writer.add_scalar('Loss_DISC/P6_ds', loss_dict['loss_adv_P6_ds'],
                              iteration)
            writer.add_scalar('Loss_DISC/P7_ds', loss_dict['loss_adv_P7_ds'],
                              iteration)
        if USE_DIS_CENTER_AWARE:
            writer.add_scalar('Loss_DISC/P3_CA_ds',
                              loss_dict['loss_adv_P3_CA_ds'], iteration)
            writer.add_scalar('Loss_DISC/P4_CA_ds',
                              loss_dict['loss_adv_P4_CA_ds'], iteration)
            writer.add_scalar('Loss_DISC/P5_CA_ds',
                              loss_dict['loss_adv_P5_CA_ds'], iteration)
            writer.add_scalar('Loss_DISC/P6_CA_ds',
                              loss_dict['loss_adv_P6_CA_ds'], iteration)
            writer.add_scalar('Loss_DISC/P7_CA_ds',
                              loss_dict['loss_adv_P7_CA_ds'], iteration)
        if USE_DIS_CONDITIONAL:
            writer.add_scalar('Loss_DISC/P3_Cond_ds',
                              loss_dict['loss_adv_P3_Cond_ds'], iteration)
            writer.add_scalar('Loss_DISC/P4_Cond_ds',
                              loss_dict['loss_adv_P4_Cond_ds'], iteration)
            writer.add_scalar('Loss_DISC/P5_Cond_ds',
                              loss_dict['loss_adv_P5_Cond_ds'], iteration)
            writer.add_scalar('Loss_DISC/P6_Cond_ds',
                              loss_dict['loss_adv_P6_Cond_ds'], iteration)
            writer.add_scalar('Loss_DISC/P7_Cond_ds',
                              loss_dict['loss_adv_P7_Cond_ds'], iteration)
            for layer in used_feature_layers:
                for i in range(3):
                    writer.add_scalar(
                        'Stat/{}/Source_{}_left'.format(layer, i),
                        stat['%s_source_left' % layer][i], iteration)
                    writer.add_scalar('Stat/{}/Source_{}_top'.format(layer, i),
                                      stat['%s_source_top' % layer][i],
                                      iteration)
        if USE_DIS_HEAD:
            writer.add_scalar('Loss_DISC/P3_HA_ds',
                              loss_dict['loss_adv_P3_HA_ds'], iteration)
            writer.add_scalar('Loss_DISC/P4_HA_ds',
                              loss_dict['loss_adv_P4_HA_ds'], iteration)
            writer.add_scalar('Loss_DISC/P5_HA_ds',
                              loss_dict['loss_adv_P5_HA_ds'], iteration)
            writer.add_scalar('Loss_DISC/P6_HA_ds',
                              loss_dict['loss_adv_P6_HA_ds'], iteration)
            writer.add_scalar('Loss_DISC/P7_HA_ds',
                              loss_dict['loss_adv_P7_HA_ds'], iteration)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss_ds=losses_reduced, **loss_dict_reduced)

        losses.backward()
        del loss_dict, losses

        ##########################################################################
        #################### (3): train D with target domain #####################
        ##########################################################################

        loss_dict, features_t, score_maps_t = foward_detector(model,
                                                              images_t,
                                                              return_maps=True)
        assert len(loss_dict) == 1 and loss_dict[
            "zero"] == 0  # without targets the detector returns only a dummy "zero" loss

        # loss_dict["loss_adv_Pn"] = model_dis_Pn(features_t["Pn"], target_label, domain='target')
        for layer in used_feature_layers:
            # detach score_map
            for map_type in score_maps_t[layer]:
                score_maps_t[layer][map_type] = score_maps_t[layer][
                    map_type].detach()
            if seperate_dis:
                if USE_DIS_GLOBAL:
                    loss_dict["loss_adv_%s_dt" % layer] = \
                        ga_dis_lambda * model["dis_%s" % layer](features_t[layer], target_label, domain='target')
                if USE_DIS_CENTER_AWARE:
                    loss_dict["loss_adv_%s_CA_dt" %layer] = \
                        ca_dis_lambda * model["dis_%s_CA" % layer](features_t[layer], target_label, score_maps_t[layer], domain='target')
                if USE_DIS_CONDITIONAL:
                    loss_cond_l, cur_stat, idx = \
                         model["dis_%s_Cond" % layer](features_t[layer], target_label, score_maps_t[layer], domain='target', alpha=alpha, conf_th=cf_th)
                    stat["%s_target_left" %
                         layer] = [s / idx for s in cur_stat]
                    loss_cond_t, cur_stat, idx = \
                        model["dis_%s_Cond_t" % layer](features_t[layer], target_label, score_maps_t[layer], domain='target', alpha=alpha, conf_th=cf_th)
                    stat["%s_target_top" % layer] = [s / idx for s in cur_stat]
                    loss_dict["loss_adv_%s_Cond_dt" %
                              layer] = cond_dis_lambda * (loss_cond_l +
                                                          loss_cond_t)
                if USE_DIS_HEAD:
                    loss_dict["loss_adv_%s_HA_dt" %layer] = \
                        ha_dis_lambda * model["dis_%s_HA" % layer](target_label, score_maps_t[layer], domain='target')
            else:
                # shared discriminators (dis_P7*) applied to the target-domain features
                if USE_DIS_GLOBAL:
                    loss_dict["loss_adv_%s_dt" % layer] = \
                    ga_dis_lambda * model["dis_P7"](features_t[layer], target_label, domain='target')
                if USE_DIS_CENTER_AWARE:
                    loss_dict["loss_adv_%s_CA_dt" % layer] = \
                    ca_dis_lambda * model["dis_P7_CA"](features_t[layer], target_label, score_maps_t[layer], domain='target')
                if USE_DIS_CONDITIONAL:
                    loss_dict["loss_adv_%s_Cond_dt" % layer] = \
                    cond_dis_lambda * model["dis_P7_Cond"](features_t[layer], target_label, score_maps_t[layer], domain='target', alpha=alpha)
                if USE_DIS_HEAD:
                    loss_dict["loss_adv_%s_HA_dt" % layer] = \
                    ha_dis_lambda * model["dis_P7_HA"](target_label, score_maps_t[layer], domain='target')

        losses = sum(loss for loss in loss_dict.values())

        writer.add_scalar('Loss_DISC/dt', losses, iteration)
        if USE_DIS_GLOBAL:
            writer.add_scalar('Loss_DISC/P3_dt', loss_dict['loss_adv_P3_dt'],
                              iteration)
            writer.add_scalar('Loss_DISC/P4_dt', loss_dict['loss_adv_P4_dt'],
                              iteration)
            writer.add_scalar('Loss_DISC/P5_dt', loss_dict['loss_adv_P5_dt'],
                              iteration)
            writer.add_scalar('Loss_DISC/P6_dt', loss_dict['loss_adv_P6_dt'],
                              iteration)
            writer.add_scalar('Loss_DISC/P7_dt', loss_dict['loss_adv_P7_dt'],
                              iteration)

        if USE_DIS_CENTER_AWARE:
            writer.add_scalar('Loss_DISC/P3_CA_dt',
                              loss_dict['loss_adv_P3_CA_dt'], iteration)
            writer.add_scalar('Loss_DISC/P4_CA_dt',
                              loss_dict['loss_adv_P4_CA_dt'], iteration)
            writer.add_scalar('Loss_DISC/P5_CA_dt',
                              loss_dict['loss_adv_P5_CA_dt'], iteration)
            writer.add_scalar('Loss_DISC/P6_CA_dt',
                              loss_dict['loss_adv_P6_CA_dt'], iteration)
            writer.add_scalar('Loss_DISC/P7_CA_dt',
                              loss_dict['loss_adv_P7_CA_dt'], iteration)

        if USE_DIS_CONDITIONAL:
            writer.add_scalar('Loss_DISC/P3_Cond_dt',
                              loss_dict['loss_adv_P3_Cond_dt'], iteration)
            writer.add_scalar('Loss_DISC/P4_Cond_dt',
                              loss_dict['loss_adv_P4_Cond_dt'], iteration)
            writer.add_scalar('Loss_DISC/P5_Cond_dt',
                              loss_dict['loss_adv_P5_Cond_dt'], iteration)
            writer.add_scalar('Loss_DISC/P6_Cond_dt',
                              loss_dict['loss_adv_P6_Cond_dt'], iteration)
            writer.add_scalar('Loss_DISC/P7_Cond_dt',
                              loss_dict['loss_adv_P7_Cond_dt'], iteration)
            for layer in used_feature_layers:
                for i in range(3):
                    writer.add_scalar(
                        'Stat/{}/Target_{}_left'.format(layer, i),
                        stat['%s_target_left' % layer][i], iteration)
                    writer.add_scalar('Stat/{}/Target_{}_top'.format(layer, i),
                                      stat['%s_target_top' % layer][i],
                                      iteration)

        if USE_DIS_HEAD:
            writer.add_scalar('Loss_DISC/P3_HA_dt',
                              loss_dict['loss_adv_P3_HA_dt'], iteration)
            writer.add_scalar('Loss_DISC/P4_HA_dt',
                              loss_dict['loss_adv_P4_HA_dt'], iteration)
            writer.add_scalar('Loss_DISC/P5_HA_dt',
                              loss_dict['loss_adv_P5_HA_dt'], iteration)
            writer.add_scalar('Loss_DISC/P6_HA_dt',
                              loss_dict['loss_adv_P6_HA_dt'], iteration)
            writer.add_scalar('Loss_DISC/P7_HA_dt',
                              loss_dict['loss_adv_P7_HA_dt'], iteration)

        # del "zero" (useless after backward)
        del loss_dict['zero']

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss_dt=losses_reduced, **loss_dict_reduced)

        # # saved GRL gradient (debugging hook, kept disabled)
        # grad_list = []
        # for layer in used_feature_layers:
        #     def save_grl_grad(grad):
        #         grad_list.append(grad)
        #     features_t[layer].register_hook(save_grl_grad)

        losses.backward()

        ##########################################################################
        ##########################################################################
        ##########################################################################
        # max_norm = 5
        # for k in model:
        #     torch.nn.utils.clip_grad_norm_(model[k].parameters(), max_norm)

        # optimizer.step()
        for k in optimizer:
            optimizer[k].step()

        if pytorch_1_1_0_or_later:
            # scheduler.step()
            for k in scheduler:
                scheduler[k].step()

        # End of training
        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        sample_layer = used_feature_layers[
            0]  # sample any one of the used feature layers
        if USE_DIS_GLOBAL:
            if seperate_dis:
                sample_optimizer = optimizer["dis_%s" % sample_layer]
            else:
                sample_optimizer = optimizer["dis_P7"]
        if USE_DIS_CENTER_AWARE:
            if seperate_dis:
                sample_optimizer = optimizer["dis_%s_CA" % sample_layer]
            else:
                sample_optimizer = optimizer["dis_P7_CA"]
        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr_backbone: {lr_backbone:.6f}",
                    "lr_fcos: {lr_fcos:.6f}",
                    "lr_dis: {lr_dis:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr_backbone=optimizer["backbone"].param_groups[0]["lr"],
                    lr_fcos=optimizer["fcos"].param_groups[0]["lr"],
                    lr_dis=sample_optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_final", **arguments)
            results = run_test(cfg, model, distributed)
            for ap_key in results[0][0].results['bbox'].keys():
                writer.add_scalar('mAP_val/{}'.format(ap_key),
                                  results[0][0].results['bbox'][ap_key],
                                  iteration)
            map50 = results[0][0].results['bbox']['AP50']
            if map50 > best_map50:
                checkpointer.save("model_best", **arguments)
                best_map50 = map50
            for k in model:
                model[k].train()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
Beispiel #13
0
def do_train_base(model, data_loader, optimizer, scheduler, checkpointer,
                  device, checkpoint_period, arguments, cfg, run_test,
                  distributed, writer):
    # Start training
    logger = logging.getLogger("fcos_core.trainer")
    logger.info("Start training")

    # model.train()
    for k in model:
        model[k].train()

    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()
    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()
    best_map50 = 0.0
    for iteration, (images_s, targets_s,
                    _) in enumerate(data_loader, start_iter):

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
        if not pytorch_1_1_0_or_later:
            # scheduler.step()
            for k in scheduler:
                scheduler[k].step()

        images_s = images_s.to(device)
        targets_s = [target_s.to(device) for target_s in targets_s]

        # optimizer.zero_grad()
        for k in optimizer:
            optimizer[k].zero_grad()

        ##########################################################################
        #################### (1): train G #####################
        ##########################################################################

        loss_dict, features_s, score_maps_s = foward_detector(
            model, images_s, targets=targets_s, return_maps=True)

        # rename loss to indicate domain
        # loss_dict = {k + "_gs": loss_dict[k] for k in loss_dict}
        loss_dict = {k + "_gs": loss_dict[k] for k in loss_dict if 'loss' in k}

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss_gs=losses_reduced, **loss_dict_reduced)

        writer.add_scalar('Loss_FCOS/gs', losses, iteration)
        writer.add_scalar('Loss_FCOS/cls_gs', loss_dict['loss_cls_gs'],
                          iteration)
        writer.add_scalar('Loss_FCOS/reg_gs', loss_dict['loss_reg_gs'],
                          iteration)
        writer.add_scalar('Loss_FCOS/centerness_gs',
                          loss_dict['loss_centerness_gs'], iteration)

        losses.backward(retain_graph=True)
        del loss_dict, losses

        ##########################################################################
        ##########################################################################
        ##########################################################################
        # max_norm = 5
        # for k in model:
        #     torch.nn.utils.clip_grad_norm_(model[k].parameters(), max_norm)

        # optimizer.step()
        for k in optimizer:
            optimizer[k].step()

        if pytorch_1_1_0_or_later:
            # scheduler.step()
            for k in scheduler:
                scheduler[k].step()

        # End of training
        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr_backbone: {lr_backbone:.6f}",
                    "lr_fcos: {lr_fcos:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr_backbone=optimizer["backbone"].param_groups[0]["lr"],
                    lr_fcos=optimizer["fcos"].param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_final", **arguments)
            results = run_test(cfg, model, distributed)
            for ap_key in results[0][0].results['bbox'].keys():
                writer.add_scalar('mAP_val/{}'.format(ap_key),
                                  results[0][0].results['bbox'][ap_key],
                                  iteration)
            map50 = results[0][0].results['bbox']['AP50']
            if map50 > best_map50:
                checkpointer.save("model_best", **arguments)
                best_map50 = map50
            for k in model:
                model[k].train()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
Beispiel #14
0
def train(cfg, local_rank, distributed, iter_clear, ignore_head):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
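    # iter_clear / ignore_head behave as in the variant above: they control whether optimizer/scheduler
    # state is restored and whether the detection head weights are loaded.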
    if iter_clear:
        load_opt = False
        load_sch = False
    else:
        load_opt = True
        load_sch = True
    if ignore_head:
        load_body = True
        load_fpn = True
        load_head = False
    else:
        load_body = True
        load_fpn = True
        load_head = True

    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT,
                                              load_opt=load_opt,
                                              load_sch=load_sch,
                                              load_body=load_body,
                                              load_fpn=load_fpn,
                                              load_head=load_head)

    arguments.update(extra_checkpoint_data)

    if iter_clear:
        arguments["iteration"] = 0

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
    )

    return model
Beispiel #15
0
def train(cfg, local_rank, distributed):
    writer = SummaryWriter('runs/{}'.format(cfg.OUTPUT_DIR))
    ##########################################################################
    ############################# Initial Model ##############################
    ##########################################################################
    model = {}
    device = torch.device(cfg.MODEL.DEVICE)

    backbone = build_backbone(cfg).to(device)
    fcos = build_rpn(cfg, backbone.out_channels).to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        backbone = torch.nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
        fcos = torch.nn.SyncBatchNorm.convert_sync_batchnorm(fcos)

    ##########################################################################
    #################### Initial Optimizer and Scheduler #####################
    ##########################################################################
    optimizer = {}
    optimizer["backbone"] = make_optimizer(cfg, backbone, name='backbone')
    optimizer["fcos"] = make_optimizer(cfg, fcos, name='fcos')
    scheduler = {}
    scheduler["backbone"] = make_lr_scheduler(cfg, optimizer["backbone"], name='backbone')
    scheduler["fcos"] = make_lr_scheduler(cfg, optimizer["fcos"], name='fcos')

    if distributed:
        backbone = torch.nn.parallel.DistributedDataParallel(
            backbone, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False
        )
        fcos = torch.nn.parallel.DistributedDataParallel(
            fcos, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False
        )

    ##########################################################################
    ########################### Save Model to Dict ###########################
    ##########################################################################
    model["backbone"] = backbone
    model["fcos"] = fcos

    ##########################################################################
    ################################ Training ################################
    ##########################################################################
    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader_source(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    do_train_base(
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        cfg,
        run_test,
        distributed,
        writer
    )

    return model