示例#1
0
def make_data_loader(cfg,
                     is_train=True,
                     augment=False,
                     max_iter=None,
                     start_iter=0):
    """Build one DataLoader per configured dataset.

    Args:
        cfg: experiment config node (reads DATASETS, SOLVER, TEST,
            DATA_LOADER and DATASET_DIR entries).
        is_train: selects cfg.DATASETS.TRAIN vs TEST, and enables target
            transforms, shuffling and drop_last.
        augment: when True, forward the flag to the transform builder for
            extra training-time augmentation.
        max_iter: if given, wrap the batch sampler so iteration stops
            after this many batches (resumable via start_iter).
        start_iter: first iteration index when resuming.

    Returns:
        A single DataLoader when is_train is True (exactly one dataset is
        expected), otherwise a list of DataLoaders.
    """
    target_transform = build_target_transform(cfg) if is_train else None
    dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST

    # Only the transform construction depends on `augment`; the previous
    # version duplicated the whole build_dataset() call in both branches.
    if augment:
        print('Augmenting training data')
        transform = build_transforms(cfg, is_train=is_train, augment=augment)
    else:
        transform = build_transforms(cfg, is_train=is_train)

    datasets = build_dataset(cfg.DATASET_DIR,
                             dataset_list,
                             transform=transform,
                             target_transform=target_transform,
                             is_train=is_train)

    shuffle = is_train

    data_loaders = []

    for dataset in datasets:
        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.sampler.SequentialSampler(dataset)

        batch_size = cfg.SOLVER.BATCH_SIZE if is_train else cfg.TEST.BATCH_SIZE
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler=sampler, batch_size=batch_size, drop_last=is_train)
        if max_iter is not None:
            # Re-wrap so the loader yields exactly `max_iter` batches.
            batch_sampler = samplers.IterationBasedBatchSampler(
                batch_sampler, num_iterations=max_iter, start_iter=start_iter)

        data_loader = DataLoader(dataset,
                                 num_workers=cfg.DATA_LOADER.NUM_WORKERS,
                                 batch_sampler=batch_sampler,
                                 pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
                                 collate_fn=BatchCollator(is_train))
        data_loaders.append(data_loader)

    if is_train:
        # during training, a single (possibly concatenated) data_loader is returned
        assert len(data_loaders) == 1
        return data_loaders[0]
    return data_loaders
示例#2
0
def setup_self_ade(cfg, args):
    """Build everything self_ade needs (model, data, optimizer) and run it."""
    logger = logging.getLogger("self_ade.setup")
    logger.info("Starting self_ade setup")

    # Model, moved onto the configured device.
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # Input augmentation and prior-box matching for the adaptation data.
    augmentation = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                     cfg.INPUT.PIXEL_MEAN,
                                     cfg.INPUT.PIXEL_STD)
    prior_matcher = MatchPrior(PriorBox(cfg)(),
                               cfg.MODEL.CENTER_VARIANCE,
                               cfg.MODEL.SIZE_VARIANCE,
                               cfg.MODEL.THRESHOLD)

    # Plain test set (first configured dataset) plus a self-supervised
    # wrapper around a transformed copy of the same data.
    test_dataset = build_dataset(dataset_list=cfg.DATASETS.TEST,
                                 is_test=True)[0]
    adaptation_dataset = build_dataset(dataset_list=cfg.DATASETS.TEST,
                                       transform=augmentation,
                                       target_transform=prior_matcher)
    ss_dataset = SelfSupervisedDataset(adaptation_dataset, cfg)

    one_sample_sampler = OneSampleBatchSampler(SequentialSampler(test_dataset),
                                               cfg.SOLVER.BATCH_SIZE,
                                               args.self_ade_iterations)
    self_ade_dataloader = DataLoader(ss_dataset,
                                     batch_sampler=one_sample_sampler,
                                     num_workers=args.num_workers)

    # Learning rate is scaled by the self_ade loss weight.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate * args.self_ade_weight,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # Initialize mixed-precision training via apex amp ('O0' = plain FP32).
    if cfg.USE_AMP:
        amp_opt_level = 'O1'
    else:
        amp_opt_level = 'O0'
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=amp_opt_level)

    execute_self_ade(cfg, args, test_dataset, self_ade_dataloader, model,
                     optimizer)
示例#3
0
文件: build.py 项目: naviocean/SSD
def make_data_loader(cfg, is_train=True, distributed=False, max_iter=None, start_iter=0):
    """Create the train/test DataLoader(s) described by *cfg*.

    Returns a single loader when training (exactly one dataset expected)
    and a list of loaders otherwise.
    """
    if is_train:
        transform = TrainAugmentation(cfg)
        target_transform = SSDTargetTransform(cfg)
    else:
        transform = TestTransform(cfg)
        target_transform = None
    datasets = build_dataset(cfg, transform=transform,
                             target_transform=target_transform, is_train=is_train)

    loaders = []
    for dataset in datasets:
        # Sampler priority: distributed > random (training) > sequential.
        if distributed:
            sampler = samplers.DistributedSampler(dataset, shuffle=is_train)
        elif is_train:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.sampler.SequentialSampler(dataset)

        batch_size = cfg.SOLVER.BATCH_SIZE if is_train else cfg.TEST.BATCH_SIZE
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler=sampler, batch_size=batch_size, drop_last=False)
        if max_iter is not None:
            # Cap the number of batches so training runs/resumes by iteration.
            batch_sampler = samplers.IterationBasedBatchSampler(
                batch_sampler, num_iterations=max_iter, start_iter=start_iter)

        loaders.append(DataLoader(dataset,
                                  num_workers=cfg.DATA_LOADER.NUM_WORKERS,
                                  batch_sampler=batch_sampler,
                                  pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
                                  collate_fn=BatchCollator(is_train)))

    if is_train:
        # A single (possibly concatenated) loader is returned for training.
        assert len(loaders) == 1
        return loaders[0]
    return loaders
示例#4
0
def _create_dg_datasets(args, cfg, logger, target_transform, train_transform):
    """Build per-domain training datasets for domain generalization.

    In "val" eval mode every domain is split into train/val parts and a
    second dict with the validation splits is returned alongside the
    train dict; otherwise only the train dict is returned.
    """
    use_val_split = args.eval_mode == "val"
    dslist = {}
    val_set_dict = {}

    if use_val_split:
        default_train, default_val = build_dataset(
            dataset_list=cfg.DATASETS.DG_SETTINGS.DEFAULT_DOMAIN,
            transform=train_transform,
            target_transform=target_transform,
            split=True)
        val_set_dict["Default domain"] = default_val
        logger.info(
            "Default domain: train split has {} elements, test split has {} elements"
            .format(len(default_train), len(default_val)))
    else:
        default_train = build_dataset(
            dataset_list=cfg.DATASETS.DG_SETTINGS.DEFAULT_DOMAIN,
            transform=train_transform,
            target_transform=target_transform)
    dslist["Default domain"] = default_train

    for element in cfg.DATASETS.DG_SETTINGS.SOURCE_DOMAINS:
        # A domain may be a single dataset name or a tuple of names.
        sets = element if isinstance(element, tuple) else (element, )

        if use_val_split:
            ds, val_set = build_dataset(dataset_list=sets,
                                        transform=train_transform,
                                        target_transform=target_transform,
                                        split=True)
            val_set_dict[element] = val_set
            logger.info(
                "Domain {}: train split has {} elements, test split has {} elements"
                .format(str(element), len(ds), len(val_set)))
        else:
            ds = build_dataset(dataset_list=sets,
                               transform=train_transform,
                               target_transform=target_transform)
        dslist[element] = ds

    if use_val_split:
        return dslist, val_set_dict
    return dslist
def train(cfg, args):
    """Train the MobileNetV1-SSD model described by *cfg*; returns do_train's result."""
    logger = logging.getLogger('SSD.trainer')

    # Model: build, move to device, then load weights (checkpoint or base net).
    model = build_mobilev1_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        model.load(args.resume)
    else:
        logger.info("Init from base net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    # Optimizer: base learning rate scaled linearly by GPU count.
    scaled_lr = cfg.SOLVER.LR * args.num_gpus
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=scaled_lr,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # Loss.
    criterion = MultiBoxLoss(neg_pos_ratio=cfg.MODEL.NEG_POS_RATIO)

    # Scheduler: LR step milestones are rescaled by GPU count as well.
    scheduler = WarmupMultiStepLR(
        optimizer=optimizer,
        milestones=[step // args.num_gpus for step in cfg.SOLVER.LR_STEPS],
        gamma=cfg.SOLVER.GAMMA,
        warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
        warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # Dataset and loader.
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE, cfg.INPUT.PIXEL_MEAN)
    target_transform = MatchPrior(PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE,
                                  cfg.MODEL.SIZE_VARIANCE, cfg.MODEL.THRESHOLD)
    train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                  transform=train_transform,
                                  target_transform=target_transform)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
    train_loader = DataLoader(train_dataset, num_workers=4, batch_sampler=batch_sampler)

    return do_train(cfg, model, train_loader, optimizer, scheduler, criterion, device, args)
示例#6
0
def make_data_loader(cfg,
                     is_train=True,
                     distributed=False,
                     max_iter=None,
                     start_iter=0):
    """Build the train/test DataLoader(s) described by the config.

    Returns a single DataLoader when training (exactly one dataset is
    expected) and a list of DataLoaders otherwise.
    """
    train_transform = build_transforms(cfg, is_train=is_train)
    target_transform = build_target_transform(cfg) if is_train else None
    dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
    print('数据集....')
    print(dataset_list)
    # 1. First, build the dataset(s).

    datasets = build_dataset(dataset_list,
                             transform=train_transform,
                             target_transform=target_transform,
                             is_train=is_train)

    shuffle = is_train or distributed

    data_loaders = []

    for dataset in datasets:
        if distributed:
            sampler = samplers.DistributedSampler(dataset, shuffle=shuffle)
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.sampler.SequentialSampler(dataset)

        batch_size = cfg.SOLVER.BATCH_SIZE if is_train else cfg.TEST.BATCH_SIZE
        # The batch size comes from the config (32 in the original setup).
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler=sampler, batch_size=batch_size, drop_last=False)
        if max_iter is not None:
            batch_sampler = samplers.IterationBasedBatchSampler(
                batch_sampler, num_iterations=max_iter, start_iter=start_iter)

        # 2. Then build the data loader; this sets the number of CPU worker
        # processes and the batch_sampler to use.
        data_loader = DataLoader(dataset,
                                 num_workers=cfg.DATA_LOADER.NUM_WORKERS,
                                 batch_sampler=batch_sampler,
                                 pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
                                 collate_fn=BatchCollator(is_train))
        data_loaders.append(data_loader)

    if is_train:
        # during training, a single (possibly concatenated) data_loader is returned
        assert len(data_loaders) == 1
        return data_loaders[0]
    return data_loaders
示例#7
0
def _create_val_datasets(args, cfg, logger):
    """Build validation splits for the default domain and each source domain.

    Returns:
        dict mapping domain key ("Default domain" or the config element)
        to its validation split. Domains processed while
        args.eval_mode != "val" get no entry.

    NOTE(review): `dslist` collects the train splits but is never returned
    or read — compare `_create_dg_datasets`, which returns both dicts.
    Confirm whether the train splits were meant to be returned here too.
    """
    dslist = {}
    val_set_dict = {}

    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE, cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
    target_transform = MatchPrior(PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
                                  cfg.MODEL.THRESHOLD)

    # The default domain is always split into train/val.
    default_domain_dataset, default_domain_val_set = build_dataset(
        dataset_list=cfg.DATASETS.DG_SETTINGS.DEFAULT_DOMAIN,
        transform=train_transform, target_transform=target_transform, split=True)
    val_set_dict["Default domain"] = default_domain_val_set
    logger.info("Default domain: train split has {} elements, test split has {} elements".format(
        len(default_domain_dataset), len(default_domain_val_set)))

    dslist["Default domain"] = default_domain_dataset
    for element in cfg.DATASETS.DG_SETTINGS.SOURCE_DOMAINS:
        # A domain entry may be a single dataset name or a tuple of names.
        if not isinstance(element, tuple):
            sets = (element,)
        else:
            sets = element

        if args.eval_mode == "val":
            ds, val_set = build_dataset(dataset_list=sets, transform=train_transform,
                                        target_transform=target_transform, split=True)
            val_set_dict[element] = val_set
            logger.info(
                "Domain {}: train split has {} elements, test split has {} elements".format(str(element), len(ds),
                                                                                            len(val_set)))
        else:
            ds = build_dataset(dataset_list=sets, transform=train_transform,
                               target_transform=target_transform)

        dslist[element] = ds

    return val_set_dict
示例#8
0
def do_evaluation(cfg, model, output_dir, distributed):
    """Run inference on every configured test dataset and evaluate it.

    Prediction indices are sharded across ranks when *distributed* is True
    and gathered afterwards; only the main process saves predictions and
    runs the evaluation for each dataset.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    assert isinstance(model, SSD), 'Wrong module.'
    test_datasets = build_dataset(dataset_list=cfg.DATASETS.TEST, is_test=True)
    device = torch.device(cfg.MODEL.DEVICE)
    model.eval()
    if not model.is_test:
        model.is_test = True
    predictor = Predictor(cfg=cfg,
                          model=model,
                          iou_threshold=cfg.TEST.NMS_THRESHOLD,
                          score_threshold=cfg.TEST.CONFIDENCE_THRESHOLD,
                          device=device)

    cpu_device = torch.device("cpu")
    logger = logging.getLogger("SSD.inference")
    for dataset_name, test_dataset in zip(cfg.DATASETS.TEST, test_datasets):
        logger.info("Test dataset {} size: {}".format(dataset_name,
                                                      len(test_dataset)))
        indices = list(range(len(test_dataset)))
        if distributed:
            # Shard the image indices round-robin across ranks.
            indices = indices[distributed_util.get_rank()::distributed_util.
                              get_world_size()]

        # show progress bar only on main process.
        progress_bar = tqdm if distributed_util.is_main_process() else iter
        logger.info('Progress on {} 0:'.format(cfg.MODEL.DEVICE.upper()))
        predictions = {}
        for i in progress_bar(indices):
            image = test_dataset.get_image(i)
            output = predictor.predict(image)
            boxes, labels, scores = [o.to(cpu_device).numpy() for o in output]
            predictions[i] = (boxes, labels, scores)
        distributed_util.synchronize()
        predictions = _accumulate_predictions_from_multiple_gpus(predictions)
        if not distributed_util.is_main_process():
            # BUG FIX: this was `return`, which made non-main ranks leave the
            # function after the first dataset while the main rank went on to
            # the collective calls (synchronize / accumulate) for the remaining
            # datasets — a hang whenever more than one test dataset is
            # configured. Non-main ranks must keep looping instead.
            continue

        final_output_dir = os.path.join(output_dir, dataset_name)
        if not os.path.exists(final_output_dir):
            os.makedirs(final_output_dir)
        torch.save(predictions,
                   os.path.join(final_output_dir, 'predictions.pth'))
        evaluate(dataset=test_dataset,
                 predictions=predictions,
                 output_dir=final_output_dir)
示例#9
0
def do_evaluation(cfg, model, output_dir, distributed, datasets_dict=None):
    """Evaluate *model* and return a dict of metrics keyed by dataset name.

    Args:
        cfg: config node; cfg.TEST.MODE selects "split" (one evaluation per
            dataset) or joint evaluation over all of them.
        model: SSD model, possibly wrapped in DistributedDataParallel.
        output_dir: directory passed through to _evaluation.
        distributed: forwarded to _evaluation.
        datasets_dict: optional {name: dataset} mapping (e.g. validation
            splits); when None the configured test sets are built instead.

    Returns:
        dict mapping dataset name (or the joint name) to its metric.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    assert isinstance(model, SSD), 'Wrong module.'
    if datasets_dict is not None:
        if cfg.TEST.MODE == "joint":
            test_datasets = DetectionConcatDataset(datasets_dict.values())
        else:
            test_datasets = datasets_dict.values()
        joint_dataset_name = "Concatenation of validation splits"
        datasets_names = datasets_dict.keys()
    else:
        # NOTE(review): in "joint" mode this list of datasets is later passed
        # to _evaluation as-is, without DetectionConcatDataset — confirm that
        # is intended for the datasets_dict=None path.
        test_datasets = build_dataset(dataset_list=cfg.DATASETS.TEST,
                                      is_test=True)
        datasets_names = cfg.DATASETS.TEST
        joint_dataset_name = "Concatenation of test sets"

    device = torch.device(cfg.MODEL.DEVICE)
    model.eval()
    predictor = Predictor(cfg=cfg,
                          iou_threshold=cfg.TEST.NMS_THRESHOLD,
                          score_threshold=cfg.TEST.CONFIDENCE_THRESHOLD,
                          device=device,
                          model=model)
    # evaluate all test datasets.
    logger = logging.getLogger("SSD.inference")

    metrics = {}

    if cfg.TEST.MODE == "split":
        logger.info('Will evaluate {} dataset(s):'.format(len(test_datasets)))
        for dataset_name, test_dataset in zip(datasets_names, test_datasets):
            metric = _evaluation(cfg, dataset_name, test_dataset, predictor,
                                 distributed, output_dir)
            metrics[dataset_name] = metric
            distributed_util.synchronize()
    else:
        logger.info('Will evaluate {} image(s):'.format(len(test_datasets)))
        metric = _evaluation(cfg, joint_dataset_name, test_datasets, predictor,
                             distributed, output_dir)
        metrics[joint_dataset_name] = metric
        distributed_util.synchronize()

    return metrics
示例#10
0
def do_evaluation(cfg, model, output_dir, distributed):
    """Evaluate *model* on every dataset listed in cfg.DATASETS.TEST."""
    # Unwrap DDP so we operate on the raw SSD module.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    assert isinstance(model, SSD), 'Wrong module.'

    test_datasets = build_dataset(dataset_list=cfg.DATASETS.TEST, is_test=True)
    model.eval()
    predictor = Predictor(cfg=cfg,
                          model=model,
                          iou_threshold=cfg.TEST.NMS_THRESHOLD,
                          score_threshold=cfg.TEST.CONFIDENCE_THRESHOLD,
                          device=torch.device(cfg.MODEL.DEVICE))

    # evaluate all test datasets.
    logger = logging.getLogger("SSD.inference")
    logger.info('Will evaluate {} dataset(s):'.format(len(test_datasets)))
    for name, dataset in zip(cfg.DATASETS.TEST, test_datasets):
        _evaluation(cfg, name, dataset, predictor, distributed, output_dir)
        distributed_util.synchronize()
示例#11
0
def active_train(cfg, args):
    """Run an active-learning training loop.

    Trains on an initial labeled pool, then for args.query_step rounds:
    queries new samples with the configured strategy, adds them to the
    labeled pool, retrains, scores, and appends one CSV row per round.

    Returns:
        The final underlying detection model (model.model).
    """
    logger = logging.getLogger("SSD.trainer")
    raw_model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    raw_model.to(device)

    # Learning rate scaled linearly with the number of GPUs.
    lr = cfg.SOLVER.LR * args.num_gpus
    optimizer = make_optimizer(cfg, raw_model, lr)

    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = make_lr_scheduler(cfg, optimizer, milestones)

    arguments = {"iteration": 0}

    checkpointer = None
    # Only rank 0 writes checkpoints to disk.
    save_to_disk = dist_util.get_rank() == 0
    checkpointer = CheckPointer(raw_model, optimizer, scheduler,
                                args.model_dir, save_to_disk, logger)

    # NOTE(review): max_iter is computed but not used in this function —
    # confirm whether it was meant to be forwarded somewhere.
    max_iter = cfg.SOLVER.MAX_ITER // args.num_gpus

    is_train = True
    train_transform = build_transforms(cfg, is_train=is_train)
    target_transform = build_target_transform(cfg) if is_train else None
    dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
    datasets = build_dataset(dataset_list,
                             transform=train_transform,
                             target_transform=target_transform,
                             is_train=is_train)

    logger.info(f'Creating query loader...')
    query_loader = QueryLoader(datasets[0], args, cfg)

    logger.info(f'Creating al model...')
    strategy = get_strategy(args.strategy)
    model = ALModel(raw_model, strategy, optimizer, device, scheduler,
                    arguments, args, checkpointer, cfg)

    # Initial training round on the seed labeled pool.
    logger.info(f'Training on initial data with size {args.init_size}...')
    n_bbox = query_loader.len_annotations()
    t1 = time.time()
    model.fit(query_loader.get_labeled_loader())
    init_time = time.time() - t1
    logger.info(f'Scoring after initial training...')
    score = model.score()
    logger.info(f'SCORE : {score:.4f}')

    # CSV row: strategy, params, step, score, train time, active time,
    # total time, labeled-set size, number of annotated boxes.
    fields = [
        args.strategy, {}, 0, score, init_time, 0, init_time,
        len(query_loader), n_bbox
    ]
    save_to_csv(args.filename, fields)

    for step in range(args.query_step):
        logger.info(f'STEP NUMBER {step}')
        logger.info('Querying assets to label')
        t1 = time.time()
        query_idx = model.query(
            unlabeled_loader=query_loader.get_unlabeled_loader(),
            cfg=cfg,
            args=args,
            step=step,
            n_instances=args.query_size,
            length_ds=len(datasets[0]))
        logger.info('Adding labeled samples to train dataset')
        query_loader.add_to_labeled(query_idx, step + 1)
        t2 = time.time()
        logger.info('Fitting with new data...')
        model.fit(query_loader.get_labeled_loader())
        # t1 spans query + labeling + training; t2 spans training only.
        total_time = time.time() - t1
        train_time = time.time() - t2
        active_time = total_time - train_time
        logger.info('Scoring model...')
        score = model.score()
        n_bbox = query_loader.len_annotations()
        fields = [
            args.strategy, {}, step + 1, score, train_time, active_time,
            total_time,
            len(query_loader), n_bbox
        ]
        save_to_csv(args.filename, fields)
        logger.info(f'SCORE : {score:.4f}')

    return model.model
示例#12
0
    def _read_image(self, image_id):
        """Load the JPEG for *image_id* from JPEGImages and return it as an RGB numpy array."""
        image_file = os.path.join(self.data_dir, "JPEGImages",
                                  "%s.jpg" % image_id)
        return np.array(Image.open(image_file).convert("RGB"))


if __name__ == '__main__':
    # Smoke test: build the configured training dataset(s) and inspect
    # one sample from the first dataset.
    from ssd.config import cfg
    from ssd.data.transforms import build_transforms, build_target_transform
    from ssd.data.datasets import build_dataset

    is_train = True

    train_transform = build_transforms(cfg, is_train=is_train)
    target_transform = build_target_transform(cfg) if is_train else None
    dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
    datasets = build_dataset(dataset_list,
                             transform=train_transform,
                             target_transform=target_transform,
                             is_train=is_train)

    # Fetch sample 200 and print the shapes of its image/boxes/labels.
    image, targets, index = datasets[0].__getitem__(200)
    boxes = targets['boxes']
    labels = targets['labels']
    print(image.shape)
    print(boxes.shape)
    print(labels.shape)
    print(index)
def train(cfg, args):
    """Train an SSD model with Adam, resuming from a checkpoint when given.

    Returns whatever do_train returns.
    """
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.resume:
        # Resume weights and the iteration counter from a checkpoint file.
        logger.info("Resume from the model {}".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        iteration = checkpoint['iteration']
        print('iteration:', iteration)
    elif args.vgg:
        iteration = 0
        logger.info("Init from backbone net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    else:
        iteration = 0
        logger.info("all init from kaiming init")
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    # SGD alternative kept for reference:
    #optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    print('cfg.SOLVER.WEIGHT_DECAY:', cfg.SOLVER.WEIGHT_DECAY)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    # LR step milestones rescaled by GPU count, matching the scaled LR above.
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    # Apply data augmentation to the raw input images.
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.IOU_THRESHOLD, cfg.MODEL.PRIORS.DISTANCE_THRESHOLD)
    train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                  transform=train_transform,
                                  target_transform=target_transform,
                                  args=args)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    sampler = torch.utils.data.RandomSampler(train_dataset)
    # sampler = torch.utils.data.SequentialSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
    train_loader = DataLoader(train_dataset,
                              num_workers=4,
                              batch_sampler=batch_sampler,
                              pin_memory=True)

    return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                    args, iteration)
示例#14
0
def train(cfg, args):
    """Train an SSD model, optionally with domain-generalization (DG) source
    domains, self-supervision, mixed precision (apex amp) and DDP.

    Returns whatever do_train returns.
    """
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    # LR step milestones rescaled by GPU count, matching the scaled LR above.
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Load weights or restore checkpoint
    # -----------------------------------------------------------------------------
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        restore_training_checkpoint(logger,
                                    model,
                                    args.resume,
                                    optimizer=optimizer,
                                    scheduler=scheduler)
    else:
        logger.info("Init from base net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)

    # Initialize mixed-precision training (apex amp); 'O0' is plain FP32.
    use_mixed_precision = cfg.USE_AMP
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=amp_opt_level)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN,
                                        cfg.INPUT.PIXEL_STD)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.THRESHOLD)

    if cfg.DATASETS.DG:
        # Domain-generalization path: one DataLoader per source domain.
        if args.eval_mode == "val":
            dslist, val_set_dict = _create_dg_datasets(args, cfg, logger,
                                                       target_transform,
                                                       train_transform)
        else:
            dslist = _create_dg_datasets(args, cfg, logger, target_transform,
                                         train_transform)

        logger.info("Sizes of sources datasets:")
        for k, v in dslist.items():
            logger.info("{} size: {}".format(k, len(v)))

        dataloaders = []
        for name, train_dataset in dslist.items():
            sampler = torch.utils.data.RandomSampler(train_dataset)
            batch_sampler = torch.utils.data.sampler.BatchSampler(
                sampler=sampler,
                batch_size=cfg.SOLVER.BATCH_SIZE,
                drop_last=True)

            batch_sampler = samplers.IterationBasedBatchSampler(
                batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER)

            if cfg.MODEL.SELF_SUPERVISED:
                # Wrap the dataset so samples also carry self-supervised targets.
                ss_dataset = SelfSupervisedDataset(train_dataset, cfg)
                train_loader = DataLoader(ss_dataset,
                                          num_workers=args.num_workers,
                                          batch_sampler=batch_sampler,
                                          pin_memory=True)
            else:
                train_loader = DataLoader(train_dataset,
                                          num_workers=args.num_workers,
                                          batch_sampler=batch_sampler,
                                          pin_memory=True)
            dataloaders.append(train_loader)

        if args.eval_mode == "val":
            if args.return_best:
                # Validation splits are passed so training can keep the best model.
                return do_train(cfg, model, dataloaders, optimizer, scheduler,
                                device, args, val_set_dict)
            else:
                return do_train(cfg, model, dataloaders, optimizer, scheduler,
                                device, args)
        else:
            return do_train(cfg, model, dataloaders, optimizer, scheduler,
                            device, args)

    # No DG:
    if args.eval_mode == "val":
        # Hold out a validation split from the training data.
        train_dataset, val_dataset = build_dataset(
            dataset_list=cfg.DATASETS.TRAIN,
            transform=train_transform,
            target_transform=target_transform,
            split=True)
    else:
        train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                      transform=train_transform,
                                      target_transform=target_transform)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)

    if cfg.MODEL.SELF_SUPERVISED:
        ss_dataset = SelfSupervisedDataset(train_dataset, cfg)
        train_loader = DataLoader(ss_dataset,
                                  num_workers=args.num_workers,
                                  batch_sampler=batch_sampler,
                                  pin_memory=True)
    else:
        train_loader = DataLoader(train_dataset,
                                  num_workers=args.num_workers,
                                  batch_sampler=batch_sampler,
                                  pin_memory=True)

    if args.eval_mode == "val":
        return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                        args, {"validation_split": val_dataset})
    else:
        return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                        args)