示例#1
0
def train(cfg, local_rank, distributed):
    """Build the detection model and its training state, then call ``do_train``.

    Args:
        cfg: Configuration object. This function reads ``DEVICE``, ``DTYPE``,
            ``OUTPUT_DIR``, ``MODEL.WEIGHT`` and ``SOLVER.CHECKPOINT_PERIOD``
            and forwards ``cfg`` to the builder helpers.
        local_rank: Device index handed to ``DistributedDataParallel`` as the
            only ``device_ids`` entry and as ``output_device``.
        distributed: When truthy, the model is wrapped in
            ``DistributedDataParallel`` and the data loader is created with
            ``is_distributed=True``.

    Returns:
        The model object that was passed to ``do_train`` (the DDP wrapper when
        ``distributed`` is truthy).
    """
    model = build_detection_model(cfg)
    device = torch.device(cfg.DEVICE)
    model.to(device)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_scheduler(cfg, optimizer)

    # 'O1' is requested only for float16 configs; everything else uses 'O0'.
    opt_level = 'O1' if cfg.DTYPE == 'float16' else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

    if distributed:
        ddp_options = dict(
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False,
        )
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    arguments = {'iteration': 0}

    # Only rank 0 is told to save to disk.
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.OUTPUT_DIR,
                                get_rank() == 0)
    # Whatever load() returns may override the starting iteration.
    arguments.update(checkpointer.load(cfg.MODEL.WEIGHT))

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments['iteration'],
    )

    do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             cfg.SOLVER.CHECKPOINT_PERIOD, arguments)

    return model
示例#2
0
def start_train(cfg):
    """Set up an SSD detector with Nesterov SGD and cosine annealing, then train.

    Args:
        cfg: Configuration object. Reads ``SOLVER.LR``, ``SOLVER.MOMENTUM``,
            ``SOLVER.WEIGHT_DECAY``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = torch_utils.to_cuda(SSDDetector(cfg))

    # Only parameters that still require gradients are optimised.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.SGD(
        trainable,
        lr=cfg.SOLVER.LR,
        momentum=cfg.SOLVER.MOMENTUM,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        nesterov=True,
    )
    # T_max is the iteration budget divided by 1000, truncated by int().
    cosine_period = int(cfg.SOLVER.MAX_ITER / 1000)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cosine_period, eta_min=0)

    arguments = {"iteration": 0}
    checkpointer = CheckPointer(model, optimizer, cfg.OUTPUT_DIR, True, logger)
    # Values returned by load() may override the starting iteration.
    arguments.update(checkpointer.load())

    iterations = cfg.SOLVER.MAX_ITER
    train_loader = make_data_loader(
        cfg,
        is_train=True,
        max_iter=iterations,
        start_iter=arguments['iteration'],
    )

    return do_train(cfg, model, train_loader, optimizer, checkpointer,
                    arguments, scheduler)
示例#3
0
def start_train(cfg):
    """Set up an SSD detector with plain SGD and hand it to ``do_train``.

    Args:
        cfg: Configuration object. Reads ``SOLVER.LR``, ``SOLVER.MOMENTUM``,
            ``SOLVER.WEIGHT_DECAY``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = torch_utils.to_cuda(SSDDetector(cfg))

    sgd_options = dict(
        lr=cfg.SOLVER.LR,
        momentum=cfg.SOLVER.MOMENTUM,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )
    optimizer = torch.optim.SGD(model.parameters(), **sgd_options)

    arguments = {"iteration": 0}
    checkpointer = CheckPointer(model, optimizer, cfg.OUTPUT_DIR, True, logger)
    # Values returned by load() may override the starting iteration.
    arguments.update(checkpointer.load())

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        max_iter=cfg.SOLVER.MAX_ITER,
        start_iter=arguments['iteration'],
    )

    return do_train(cfg, model, train_loader, optimizer, checkpointer,
                    arguments)
示例#4
0
def start_train(cfg):
    """Build the SSD detector, optimizer, scheduler and loader, then train.

    Args:
        cfg: Configuration object. Reads ``SOLVER.LR``, ``SOLVER.LR_STEPS``,
            ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``; also forwarded to the
            ``make_*`` helpers, ``CheckPointer`` and ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = torch_utils.to_cuda(SSDDetector(cfg))

    optimizer = make_optimizer(cfg, model, cfg.SOLVER.LR)

    # The scheduler gets its own list copy of the configured steps.
    lr_steps = list(cfg.SOLVER.LR_STEPS)
    scheduler = make_lr_scheduler(cfg, optimizer, lr_steps)

    arguments = {"iteration": 0}
    checkpointer = CheckPointer(cfg, model, optimizer, scheduler,
                                cfg.OUTPUT_DIR, True, logger)
    # Values returned by load() may override the starting iteration.
    arguments.update(checkpointer.load())

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        max_iter=cfg.SOLVER.MAX_ITER,
        start_iter=arguments['iteration'],
    )

    return do_train(cfg, model, train_loader, optimizer, scheduler,
                    checkpointer, arguments)
示例#5
0
def train(cfg, args):
    """Prepare model, optimizer, scheduler, checkpointer and loader; then train.

    Args:
        cfg: Configuration object. Reads ``MODEL.DEVICE``, ``SOLVER.LR``,
            ``SOLVER.LR_STEPS``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.
        args: Run options. Reads ``distributed``, ``local_rank`` and
            ``num_gpus``; also forwarded to ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
        )

    gpu_count = args.num_gpus
    # The base learning rate is multiplied by the GPU count, while the LR
    # steps and the iteration budget are floor-divided by it.
    optimizer = make_optimizer(cfg, model, cfg.SOLVER.LR * gpu_count)
    lr_steps = [step // gpu_count for step in cfg.SOLVER.LR_STEPS]
    scheduler = make_lr_scheduler(cfg, optimizer, lr_steps)

    arguments = {"iteration": 0}
    # Only rank 0 is told to save to disk.
    is_main_process = dist_util.get_rank() == 0
    checkpointer = CheckPointer(model, optimizer, scheduler, cfg.OUTPUT_DIR,
                                is_main_process, logger)
    # Values returned by load() may override the starting iteration.
    arguments.update(checkpointer.load())

    iterations = cfg.SOLVER.MAX_ITER // gpu_count
    train_loader = make_data_loader(
        cfg,
        is_train=True,
        distributed=args.distributed,
        max_iter=iterations,
        start_iter=arguments['iteration'],
    )

    return do_train(cfg, model, train_loader, optimizer, scheduler,
                    checkpointer, device, arguments, args)
示例#6
0
def train(cfg, args):
    """Prepare the training state, print the trainable parameter count, train.

    Args:
        cfg: Configuration object. Reads ``MODEL.DEVICE``, ``SOLVER.LR``,
            ``SOLVER.LR_STEPS``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.
        args: Run options. Reads ``distributed``, ``local_rank`` and
            ``num_gpus``; also forwarded to ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = make_optimizer(cfg, model, lr)

    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = make_lr_scheduler(cfg, optimizer, milestones)

    arguments = {"iteration": 0}
    # Only rank 0 is told to save to disk.
    save_to_disk = dist_util.get_rank() == 0
    checkpointer = CheckPointer(model, optimizer, scheduler, cfg.OUTPUT_DIR,
                                save_to_disk, logger)
    # Values returned by load() may override the starting iteration.
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)

    max_iter = cfg.SOLVER.MAX_ITER // args.num_gpus
    train_loader = make_data_loader(cfg,
                                    is_train=True,
                                    distributed=args.distributed,
                                    max_iter=max_iter,
                                    start_iter=arguments['iteration'])

    # Report how many parameter elements will receive gradients.
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(n_params)

    model = do_train(cfg, model, train_loader, optimizer, scheduler,
                     checkpointer, device, arguments, args)
    return model
示例#7
0
def start_train(cfg, visualize_example=False):
    """Build an SSD detector with SGD, print it, and hand it to ``do_train``.

    Args:
        cfg: Configuration object. Reads ``SOLVER.LR``, ``SOLVER.MOMENTUM``,
            ``SOLVER.WEIGHT_DECAY``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.
        visualize_example: Forwarded unchanged to ``do_train`` as its seventh
            positional argument.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = SSDDetector(cfg)
    print(model)
    model = torch_utils.to_cuda(model)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=cfg.SOLVER.LR,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    arguments = {"iteration": 0}
    save_to_disk = True
    checkpointer = CheckPointer(
        model,
        optimizer,
        cfg.OUTPUT_DIR,
        save_to_disk,
        logger,
    )
    # Values returned by load() may override the starting iteration.
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)

    max_iter = cfg.SOLVER.MAX_ITER
    train_loader = make_data_loader(cfg,
                                    is_train=True,
                                    max_iter=max_iter,
                                    start_iter=arguments['iteration'])

    # No LR scheduler is used in this variant; do_train receives None.
    model = do_train(cfg,
                     model,
                     train_loader,
                     optimizer,
                     checkpointer,
                     arguments,
                     visualize_example,
                     lr_scheduler=None)
    return model
def train(cfg, args):
    """Train the model returned by ``build_mobilev1_ssd_model``.

    Args:
        cfg: Configuration object. Reads ``MODEL.*``, ``SOLVER.*``, ``INPUT.*``
            and ``DATASETS.TRAIN`` entries used below.
        args: Run options. Reads ``resume``, ``vgg``, ``distributed``,
            ``local_rank`` and ``num_gpus``; also forwarded to ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')

    # --- Model -------------------------------------------------------------
    model = build_mobilev1_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if not args.resume:
        logger.info("Init from base net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    else:
        logger.info("Resume from the model {}".format(args.resume))
        model.load(args.resume)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
        )

    # --- Optimizer (learning rate multiplied by the GPU count) --------------
    scaled_lr = cfg.SOLVER.LR * args.num_gpus
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=scaled_lr,
        momentum=cfg.SOLVER.MOMENTUM,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )

    # --- Criterion ----------------------------------------------------------
    criterion = MultiBoxLoss(neg_pos_ratio=cfg.MODEL.NEG_POS_RATIO)

    # --- Scheduler (steps floor-divided by the GPU count) -------------------
    lr_steps = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(
        optimizer=optimizer,
        milestones=lr_steps,
        gamma=cfg.SOLVER.GAMMA,
        warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
        warmup_iters=cfg.SOLVER.WARMUP_ITERS,
    )

    # --- Dataset ------------------------------------------------------------
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN)
    target_transform = MatchPrior(
        PriorBox(cfg)(),
        cfg.MODEL.CENTER_VARIANCE,
        cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.THRESHOLD,
    )
    train_dataset = build_dataset(
        dataset_list=cfg.DATASETS.TRAIN,
        transform=train_transform,
        target_transform=target_transform,
    )
    logger.info("Train dataset size: {}".format(len(train_dataset)))

    sampler_cls = (torch.utils.data.DistributedSampler
                   if args.distributed else torch.utils.data.RandomSampler)
    sampler = sampler_cls(train_dataset)
    batches = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    # Wrap so that iteration is bounded by the per-GPU iteration budget.
    batches = samplers.IterationBasedBatchSampler(
        batches, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
    train_loader = DataLoader(train_dataset, num_workers=4,
                              batch_sampler=batches)

    return do_train(cfg, model, train_loader, optimizer, scheduler, criterion,
                    device, args)
示例#9
0
def start_train(cfg):
    """Build an SSD detector with a configurable optimizer and train it.

    ``cfg.SOLVER.TYPE`` selects the optimizer: ``"sgd"`` gives SGD with
    momentum, ``"adam"`` gives Adam, and any other value prints a warning and
    falls back to Adam.

    Args:
        cfg: Configuration object. Reads ``SOLVER.TYPE``, ``SOLVER.LR``,
            ``SOLVER.WEIGHT_DECAY``, ``SOLVER.MOMENTUM``, ``SOLVER.MAX_ITER``
            and ``OUTPUT_DIR``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = torch_utils.to_cuda(SSDDetector(cfg))

    solver_type = cfg.SOLVER.TYPE
    if solver_type == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=cfg.SOLVER.LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
            momentum=cfg.SOLVER.MOMENTUM,
        )
    else:
        if solver_type != "adam":
            # Unknown solver names are reported and then treated as Adam.
            print("WARNING: Incorrect solver type, defaulting to Adam")
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=cfg.SOLVER.LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )

    scheduler = LinearMultiStepWarmUp(cfg, optimizer)

    arguments = {"iteration": 0}
    checkpointer = CheckPointer(model, optimizer, cfg.OUTPUT_DIR, True, logger)
    # Values returned by load() may override the starting iteration.
    arguments.update(checkpointer.load())

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        max_iter=cfg.SOLVER.MAX_ITER,
        start_iter=arguments['iteration'],
    )

    return do_train(cfg, model, train_loader, optimizer, checkpointer,
                    arguments, scheduler)
示例#10
0
def train(cfg, args):
    """Collect activation statistics with ``do_run``, enable quant, then train.

    Steps: build the model (converting to SyncBatchNorm and wrapping in DDP
    when distributed), build optimizer/scheduler/checkpointer/loader, call
    ``do_run`` once, update every ``torch.nn.ReLU`` module from its collected
    statistics, and finally call ``do_train``.

    Args:
        cfg: Configuration object. Reads ``MODEL.DEVICE``, ``SOLVER.LR``,
            ``SOLVER.LR_STEPS``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.
        args: Run options. Reads ``distributed``, ``local_rank``,
            ``num_gpus`` and ``ckpt``; also forwarded to ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = make_optimizer(cfg, model, lr)

    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = make_lr_scheduler(cfg, optimizer, milestones)

    arguments = {"iteration": 0}
    # Only rank 0 is told to save to disk.
    save_to_disk = dist_util.get_rank() == 0
    checkpointer = CheckPointer(model, optimizer, scheduler, cfg.OUTPUT_DIR,
                                save_to_disk, logger)
    # Values returned by load() may override the starting iteration.
    extra_checkpoint_data = checkpointer.load(args.ckpt)
    arguments.update(extra_checkpoint_data)

    max_iter = cfg.SOLVER.MAX_ITER // args.num_gpus
    train_loader = make_data_loader(cfg,
                                    is_train=True,
                                    distributed=args.distributed,
                                    max_iter=max_iter,
                                    start_iter=arguments['iteration'])

    # Use the trainer logger rather than the module-level logging functions:
    # those log on the root logger and implicitly call basicConfig() when the
    # root logger has no handlers.
    logger.info('==>Start statistic')
    do_run(cfg, model, distributed=args.distributed)
    logger.info('==>End statistic')

    # NOTE(review): collectStats, c, running_mean, running_std and quant are
    # not attributes of the stock torch.nn.ReLU; presumably the project
    # patches or replaces ReLU - confirm.
    for module in model.modules():
        if isinstance(module, torch.nn.ReLU):
            module.collectStats = False
            # Set c to mean + 3 * std of the collected statistics.
            module.c.data = module.running_mean + (3 * module.running_std)
            module.quant = True
    torch.cuda.empty_cache()
    model = do_train(cfg, model, train_loader, optimizer, scheduler,
                     checkpointer, device, arguments, args)
    return model
示例#11
0
def train(cfg: CfgNode,
          args: Namespace,
          output_dir: Path,
          model_manager: Dict[str, Any],
          freeze_non_sigma: bool = False):
    """Resume from the experiment's best checkpoint, print the model, train.

    Args:
        cfg: Configuration object. Reads ``MODEL.DEVICE``, ``SOLVER.LR``,
            ``SOLVER.LR_STEPS``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.
        args: Run options. Reads ``distributed``, ``local_rank`` and
            ``num_gpus``; also forwarded to ``do_train``.
        output_dir: Forwarded unchanged to ``do_train``.
        model_manager: Forwarded unchanged to ``do_train``.
        freeze_non_sigma: Accepted but not read anywhere in this function.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
        )

    gpu_count = args.num_gpus
    # LR is multiplied by the GPU count; steps and budget are floor-divided.
    optimizer = make_optimizer(cfg, model, cfg.SOLVER.LR * gpu_count)
    lr_steps = [step // gpu_count for step in cfg.SOLVER.LR_STEPS]
    scheduler = make_lr_scheduler(cfg, optimizer, lr_steps)

    arguments = {"iteration": 0}
    # Only rank 0 is told to save to disk.
    is_main_process = dist_util.get_rank() == 0
    checkpointer = CheckPointer(model, optimizer, scheduler, cfg.OUTPUT_DIR,
                                is_main_process, logger)
    best_checkpoint = checkpointer.get_best_from_experiment_dir(cfg)
    # Values returned by load() may override the starting iteration.
    arguments.update(checkpointer.load(f=best_checkpoint))

    iterations = cfg.SOLVER.MAX_ITER // gpu_count
    train_loader = make_data_loader(
        cfg,
        is_train=True,
        distributed=args.distributed,
        max_iter=iterations,
        start_iter=arguments['iteration'],
    )

    print_model(model)

    return do_train(cfg, model, train_loader, optimizer, scheduler,
                    checkpointer, device, arguments, args, output_dir,
                    model_manager)
示例#12
0
def start_train(cfg):
    """Build an SSD detector with Adam and a two-step LR schedule, then train.

    The ``MultiStepLR`` milestones are fixed at 6000 and 10000 in this
    function; only ``gamma`` comes from the configuration.

    Args:
        cfg: Configuration object. Reads ``SOLVER.LR``,
            ``SOLVER.WEIGHT_DECAY``, ``SOLVER.GAMMA``, ``SOLVER.MAX_ITER``
            and ``OUTPUT_DIR``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    model = torch_utils.to_cuda(SSDDetector(cfg))

    adam_options = dict(
        lr=cfg.SOLVER.LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )
    optimizer = torch.optim.Adam(model.parameters(), **adam_options)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        milestones=[6000, 10000],
        gamma=cfg.SOLVER.GAMMA,
    )

    arguments = {"iteration": 0}
    checkpointer = CheckPointer(model, optimizer, cfg.OUTPUT_DIR, True, logger)
    # Values returned by load() may override the starting iteration.
    arguments.update(checkpointer.load())

    train_loader = make_data_loader(
        cfg,
        is_train=True,
        max_iter=cfg.SOLVER.MAX_ITER,
        start_iter=arguments['iteration'],
    )

    return do_train(cfg, model, train_loader, optimizer, checkpointer,
                    arguments, scheduler)
示例#13
0
def train(cfg, args):
    """Prepare model, optimizer, scheduler, checkpointer and loader; then train.

    Args:
        cfg: Configuration object. Reads ``MODEL.DEVICE``, ``SOLVER.LR``,
            ``SOLVER.LR_STEPS``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.
        args: Run options. Reads ``distributed``, ``local_rank`` and
            ``num_gpus``; also forwarded to ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    # Fetch the named trainer logger.
    logger = logging.getLogger('SSD.trainer')

    # Build the detection model and move it to the configured device.
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
        )

    # Learning rate, optimizer and LR milestones: the rate is multiplied by
    # the GPU count and each milestone is floor-divided by it.
    gpu_count = args.num_gpus
    optimizer = make_optimizer(cfg, model, cfg.SOLVER.LR * gpu_count)
    lr_steps = [step // gpu_count for step in cfg.SOLVER.LR_STEPS]
    scheduler = make_lr_scheduler(cfg, optimizer, lr_steps)

    arguments = {"iteration": 0}
    # Only rank 0 is told to save to disk.
    is_main_process = dist_util.get_rank() == 0
    # Load checkpoint data; values it returns may override the starting
    # iteration (presumably this resumes interrupted training - confirm).
    checkpointer = CheckPointer(model, optimizer, scheduler, cfg.OUTPUT_DIR,
                                is_main_process, logger)
    arguments.update(checkpointer.load())

    # Build the training data loader.
    iterations = cfg.SOLVER.MAX_ITER // gpu_count
    train_loader = make_data_loader(
        cfg,
        is_train=True,
        distributed=args.distributed,
        max_iter=iterations,
        start_iter=arguments['iteration'],
    )

    # Start the actual training.
    return do_train(cfg, model, train_loader, optimizer, scheduler,
                    checkpointer, device, arguments, args)
示例#14
0
文件: train.py 项目: touchylk/lwk_SSD
def train(cfg, args):
    """Prepare the training state, print the loader's batch size, then train.

    Args:
        cfg: Configuration object. Reads ``MODEL.DEVICE``, ``SOLVER.LR``,
            ``SOLVER.LR_STEPS``, ``SOLVER.MAX_ITER`` and ``OUTPUT_DIR``.
        args: Run options. Reads ``distributed``, ``local_rank`` and
            ``num_gpus``; also forwarded to ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    # Build the model and move it to the configured device.
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
        )

    # Build the optimizer; the learning rate is multiplied by the GPU count.
    gpu_count = args.num_gpus
    optimizer = make_optimizer(cfg, model, cfg.SOLVER.LR * gpu_count)

    lr_steps = [step // gpu_count for step in cfg.SOLVER.LR_STEPS]
    scheduler = make_lr_scheduler(cfg, optimizer, lr_steps)

    arguments = {"iteration": 0}
    # Only rank 0 is told to save to disk.
    is_main_process = dist_util.get_rank() == 0
    # Checkpoint helper rooted at cfg.OUTPUT_DIR.
    checkpointer = CheckPointer(
        model,
        optimizer,
        scheduler,
        save_dir=cfg.OUTPUT_DIR,
        save_to_disk=is_main_process,
        logger=logger,
    )
    # Load with an empty path and use_latest disabled; values returned may
    # override the starting iteration.
    arguments.update(checkpointer.load(f='', use_latest=False))

    # Build the training data loader.
    iterations = cfg.SOLVER.MAX_ITER // gpu_count
    train_loader = make_data_loader(
        cfg,
        is_train=True,
        distributed=args.distributed,
        max_iter=iterations,
        start_iter=arguments['iteration'],
    )

    print("dataloader: ", train_loader.batch_size)
    # Run training.
    return do_train(cfg, model, train_loader, optimizer, scheduler,
                    checkpointer, device, arguments, args)
示例#15
0
文件: helpers.py 项目: mDuval1/SSD
 def fit(self, train_loader):
     """Run ``do_train`` with this object's stored state and keep the result.

     Args:
         train_loader: Passed to ``do_train`` as the data loader argument.

     Returns:
         ``self.model`` after it has been replaced by ``do_train``'s result.
     """
     trained = do_train(
         self.cfg,
         self.model,
         train_loader,
         self.optimizer,
         self.scheduler,
         self.checkpointer,
         self.device,
         self.arguments,
         self.args,
     )
     self.model = trained
     return self.model
示例#16
0
def train(cfg, args):
    """Train an SSD model on one dataset or on several DG source datasets.

    When ``cfg.DATASETS.DG`` is truthy, one data loader is built per entry
    returned by ``_create_dg_datasets`` and the list of loaders is passed to
    ``do_train``. Otherwise a single loader over ``cfg.DATASETS.TRAIN`` is
    used. With ``args.eval_mode == "val"`` a validation set is additionally
    forwarded: always in the single-dataset path, and only when
    ``args.return_best`` is truthy in the DG path.

    Args:
        cfg: Configuration object. Reads ``MODEL.*``, ``SOLVER.*``,
            ``INPUT.*``, ``DATASETS.*`` and ``USE_AMP`` entries used below.
        args: Run options. Reads ``num_gpus``, ``resume``, ``vgg``,
            ``distributed``, ``local_rank``, ``eval_mode``, ``return_best``
            and ``num_workers``; also forwarded to ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')

    # --- Model -------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # --- Optimizer (learning rate multiplied by the GPU count) --------------
    scaled_lr = cfg.SOLVER.LR * args.num_gpus
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=scaled_lr,
        momentum=cfg.SOLVER.MOMENTUM,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )

    # --- Scheduler (steps floor-divided by the GPU count) -------------------
    lr_steps = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(
        optimizer=optimizer,
        milestones=lr_steps,
        gamma=cfg.SOLVER.GAMMA,
        warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
        warmup_iters=cfg.SOLVER.WARMUP_ITERS,
    )

    # --- Load weights or restore checkpoint ---------------------------------
    if not args.resume:
        logger.info("Init from base net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    else:
        logger.info("Resume from the model {}".format(args.resume))
        restore_training_checkpoint(logger,
                                    model,
                                    args.resume,
                                    optimizer=optimizer,
                                    scheduler=scheduler)

    # --- Mixed precision: 'O1' when cfg.USE_AMP is truthy, else 'O0' --------
    opt_level = 'O1' if cfg.USE_AMP else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # --- Dataset ------------------------------------------------------------
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN,
                                        cfg.INPUT.PIXEL_STD)
    target_transform = MatchPrior(
        PriorBox(cfg)(),
        cfg.MODEL.CENTER_VARIANCE,
        cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.THRESHOLD,
    )

    if cfg.DATASETS.DG:
        created = _create_dg_datasets(args, cfg, logger, target_transform,
                                      train_transform)
        # In "val" mode the helper's result is unpacked into two values.
        if args.eval_mode == "val":
            dslist, val_set_dict = created
        else:
            dslist = created

        logger.info("Sizes of sources datasets:")
        for source_name, source_dataset in dslist.items():
            logger.info("{} size: {}".format(source_name,
                                             len(source_dataset)))

        dataloaders = []
        for _source_name, source_dataset in dslist.items():
            batches = torch.utils.data.sampler.BatchSampler(
                sampler=torch.utils.data.RandomSampler(source_dataset),
                batch_size=cfg.SOLVER.BATCH_SIZE,
                drop_last=True)
            # Each source loader is bounded by the full MAX_ITER here.
            batches = samplers.IterationBasedBatchSampler(
                batches, num_iterations=cfg.SOLVER.MAX_ITER)

            if cfg.MODEL.SELF_SUPERVISED:
                loader_dataset = SelfSupervisedDataset(source_dataset, cfg)
            else:
                loader_dataset = source_dataset
            dataloaders.append(
                DataLoader(loader_dataset,
                           num_workers=args.num_workers,
                           batch_sampler=batches,
                           pin_memory=True))

        if args.eval_mode == "val" and args.return_best:
            return do_train(cfg, model, dataloaders, optimizer, scheduler,
                            device, args, val_set_dict)
        return do_train(cfg, model, dataloaders, optimizer, scheduler, device,
                        args)

    # --- No DG: single training dataset -------------------------------------
    validating = args.eval_mode == "val"
    if validating:
        train_dataset, val_dataset = build_dataset(
            dataset_list=cfg.DATASETS.TRAIN,
            transform=train_transform,
            target_transform=target_transform,
            split=True)
    else:
        train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                      transform=train_transform,
                                      target_transform=target_transform)
    logger.info("Train dataset size: {}".format(len(train_dataset)))

    sampler_cls = (torch.utils.data.DistributedSampler
                   if args.distributed else torch.utils.data.RandomSampler)
    batches = torch.utils.data.sampler.BatchSampler(
        sampler=sampler_cls(train_dataset),
        batch_size=cfg.SOLVER.BATCH_SIZE,
        drop_last=False)
    batches = samplers.IterationBasedBatchSampler(
        batches, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)

    if cfg.MODEL.SELF_SUPERVISED:
        loader_dataset = SelfSupervisedDataset(train_dataset, cfg)
    else:
        loader_dataset = train_dataset
    train_loader = DataLoader(loader_dataset,
                              num_workers=args.num_workers,
                              batch_sampler=batches,
                              pin_memory=True)

    if validating:
        return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                        args, {"validation_split": val_dataset})
    return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                    args)
def train(cfg, args):
    """Train an SSD model with Adam, optionally resuming from a checkpoint.

    Weight initialisation: ``args.resume`` loads ``state_dict`` and
    ``iteration`` from a checkpoint file; otherwise ``args.vgg`` initialises
    from a base network; otherwise the model keeps whatever initialisation
    ``build_ssd_model`` produced.

    Args:
        cfg: Configuration object. Reads ``MODEL.*``, ``SOLVER.*``, ``INPUT.*``
            and ``DATASETS.TRAIN`` entries used below.
        args: Run options. Reads ``resume``, ``vgg`` and ``num_gpus``; also
            forwarded to ``build_dataset`` and ``do_train``.

    Returns:
        Whatever ``do_train`` returns.
    """
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        # map_location lets a checkpoint saved on another device (e.g. a GPU
        # that is absent here) be loaded onto the configured device.
        # NOTE: torch.load unpickles the file; only resume from trusted
        # checkpoints.
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        iteration = checkpoint['iteration']
        print('iteration:', iteration)
    elif args.vgg:
        iteration = 0
        logger.info("Init from backbone net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    else:
        iteration = 0
        logger.info("all init from kaiming init")
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    print('cfg.SOLVER.WEIGHT_DECAY:', cfg.SOLVER.WEIGHT_DECAY)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    # Data augmentation applied to the raw images.
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.IOU_THRESHOLD, cfg.MODEL.PRIORS.DISTANCE_THRESHOLD)
    train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                  transform=train_transform,
                                  target_transform=target_transform,
                                  args=args)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    sampler = torch.utils.data.RandomSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    # Wrap so that iteration is bounded by the per-GPU iteration budget.
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
    train_loader = DataLoader(train_dataset,
                              num_workers=4,
                              batch_sampler=batch_sampler,
                              pin_memory=True)

    return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                    args, iteration)