예제 #1
0
    def __init__(self, root, batch_size, train=True):
        """Build a CIFAR-10 DataLoader.

        Args:
            root: Directory where CIFAR-10 is stored (downloaded if missing).
            batch_size: Number of samples per batch.
            train: If True, load the training split with random crop / flip
                augmentation; otherwise load the test split with only tensor
                conversion and normalization.
        """
        # from torchvision import transforms
        # Per-channel mean / std shared by both pipelines.
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # Evaluation must not see random crops/flips; previously the training
        # augmentation was applied to the test split as well.
        transform = transform_train if train else transform_test
        # from torchvision import datasets
        dataset = datasets.CIFAR10(root,
                                   train=train,
                                   transform=transform,
                                   download=True)
        # Distributed: shard the training set across processes. The sampler
        # then owns the ordering, and DataLoader rejects shuffle=True together
        # with a sampler, hence shuffle=(sampler is None).
        sampler = None
        if train and distributed_is_initialized():
            sampler = data.DistributedSampler(dataset)
        super(CIFAR10DataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=(sampler is None),
            sampler=sampler,
        )
예제 #2
0
 def __init__(self,
              dataset: dat.Dataset,
              num_workers: int = 4,
              chunk_size: int = 64000,
              batch_size: int = 16,
              distributed: bool = False,
              train: bool = True) -> None:
     """Wrap *dataset* in a loader of whole egs that are later chunked.

     Args:
         dataset: Source dataset; its items are batched by ``self._collate``.
         num_workers: Worker processes for the underlying eg DataLoader.
         chunk_size: Chunk length handed to ``ChunkSplitter``; the hop is
             half of it.
         batch_size: Stored on the instance. The eg loader itself uses
             ``min(batch_size, 64)``; this is NOT the batch size of audio chunks.
         distributed: If True, shard egs with ``dat.DistributedSampler``
             using ``dist.world_size()`` / ``dist.rank()``.
         train: Shuffle egs (through the sampler or the loader) when True.
     """
     self.dataset = dataset
     self.train = train
     self.batch_size = batch_size
     # hop = chunk_size // 2, presumably yielding 50%-overlapping chunks
     # (ChunkSplitter semantics are not visible here; confirm).
     self.splitter = ChunkSplitter(chunk_size, train=train, hop=chunk_size // 2)

     eg_sampler = None
     if distributed:
         eg_sampler = dat.DistributedSampler(dataset,
                                             shuffle=train,
                                             num_replicas=dist.world_size(),
                                             rank=dist.rank())
     self.sampler = eg_sampler

     # The loader yields batches of whole egs and supports multiple workers;
     # chunk-level batching happens elsewhere.
     eg_batch = min(batch_size, 64)
     # DataLoader forbids shuffle=True together with an explicit sampler.
     loader_shuffle = train and self.sampler is None
     self.eg_loader = dat.DataLoader(self.dataset,
                                     collate_fn=self._collate,
                                     sampler=self.sampler,
                                     shuffle=loader_shuffle,
                                     num_workers=num_workers,
                                     batch_size=eg_batch)
예제 #3
0
def create_data_loader(
    train_cfg: dict,
    data_dir: pathlib.Path,
    anchors: torch.tensor,
    batch_size: int,
    world_size: int,
    val: bool,
    image_size: Tuple[int, int],
    num_classes: int,
    dataset_name: str,
) -> Tuple[data.DataLoader, data.Sampler]:
    """Create the dataloader (and its sampler) for training or evaluation.

    Args:
        train_cfg: The parameters related to the training regime (currently unused).
        data_dir: Directory containing an ``images`` sub-directory and an
            ``annotations.json`` metadata file.
        anchors: The tensor of anchors in the model (used by the training collate).
        batch_size: The loader's batch size.
        world_size: Number of processes; a DistributedSampler is used when > 1.
        val: Whether or not this loader is for validation. Only the training
            split is shuffled.
        image_size: ``(width, height)`` of input images into the model.
        num_classes: Number of classes (used by the training collate).
        dataset_name: Currently unused.

    Returns:
        The dataloader and the loader's sampler (``None`` when ``world_size <= 1``).
        For _training_ with a DistributedSampler the caller has to set the
        sampler's epoch to reshuffle.

    Raises:
        AssertionError: If ``data_dir`` is not a directory.
    """
    assert data_dir.is_dir(), data_dir

    dataset_ = dataset.DetectionDataset(
        data_dir=data_dir / "images",
        metadata_path=data_dir / "annotations.json",
        img_width=image_size[0],
        img_height=image_size[1],
        validation=val,
    )
    print(dataset_)

    # Logical `not`, not bitwise `~`: `~True == -2` is truthy, which would
    # shuffle validation data too.
    shuffle = not val

    # If using distributed training, use a DistributedSampler to load exclusive
    # sets of data per process.
    sampler = None
    if world_size > 1:
        sampler = data.DistributedSampler(dataset_, shuffle=shuffle)

    if val:
        collate_fn = collate.CollateVal()
    else:
        collate_fn = collate.Collate(num_classes=num_classes,
                                     original_anchors=anchors)

    loader = data.DataLoader(
        dataset_,
        batch_size=batch_size,
        # Without a sampler the loader must shuffle training data itself;
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=shuffle and sampler is None,
        pin_memory=True,
        sampler=sampler,
        collate_fn=collate_fn,
        # NOTE(review): `max` guarantees at least 8 workers even on small
        # machines; if a cap was intended this should be `min` -- confirm.
        num_workers=max(torch.multiprocessing.cpu_count() // world_size, 8),
        # NOTE(review): the last partial batch is dropped for validation only,
        # so some validation samples are skipped -- confirm this is intended.
        drop_last=bool(val),
    )
    return loader, sampler
예제 #4
0
 def __init__(self,
              version,
              phase,
              transform_name,
              factor=0.8,
              category_pool_name=None,
              label_indent=-1,
              lmdb=None,
              distributed=False,
              **kwargs):
     """Build a loader over ``_MiniImagenetLMDBHorizontal``.

     Args:
         version, phase, factor, category_pool_name, label_indent, lmdb:
             Forwarded positionally to ``_MiniImagenetLMDBHorizontal``.
         transform_name: Selects the ``(transform_fn_x, transform_fn_y)`` pair
             registered for this class in ``dataset_transforms``.
         distributed: If truthy, shard data with ``torda.DistributedSampler``.
         **kwargs: Extra arguments for the parent loader. ``collate_fn`` and
             ``sampler`` are always overwritten; ``shuffle`` (default False)
             is honored only when no sampler is used.

     Raises:
         RuntimeError: If ``distributed`` is set but ``torch.distributed`` is
             unavailable or its process group is not initialized.
     """
     dataset = _MiniImagenetLMDBHorizontal(MINIIMAGENET_DATASET_DIR,
                                           version, phase, factor,
                                           category_pool_name, label_indent,
                                           lmdb)
     transform_fn_x, transform_fn_y = dataset_transforms[
         self.__class__.__name__][transform_name]
     # Transforms are applied at collate time.
     collate_fn = functools.partial(default_collate_fn,
                                    transform_fn_x=transform_fn_x,
                                    transform_fn_y=transform_fn_y)
     if distributed:
         if not torch.distributed.is_available():
             raise RuntimeError(
                 "Requires distributed package to be available.")
         if not torch.distributed.is_initialized():
             raise RuntimeError(
                 "Requires distributed process group initialized first.")
         sampler = torda.DistributedSampler(dataset)
     else:
         sampler = None
     kwargs['collate_fn'] = collate_fn
     kwargs['sampler'] = sampler
     # A sampler and shuffle=True are mutually exclusive in DataLoader.
     kwargs['shuffle'] = kwargs.get('shuffle', False) and (sampler is None)
     super().__init__(dataset, **kwargs)
예제 #5
0
    def __init__(self, root, batch_size, train=True):
        """Build an MNIST DataLoader.

        Args:
            root: Directory where MNIST is stored (downloaded if missing).
            batch_size: Number of samples per batch.
            train: Load the training split; when a distributed process group
                is initialized, the training split is also sharded with a
                DistributedSampler.
        """
        # Tensor conversion followed by single-channel normalization.
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        mnist = datasets.MNIST(root, train=train, transform=mnist_transform, download=True)

        shard = train and distributed_is_initialized()
        sampler = data.DistributedSampler(mnist) if shard else None

        # With a sampler present, DataLoader's own shuffling must be off.
        super(MNISTDataLoader, self).__init__(
            mnist,
            sampler=sampler,
            shuffle=sampler is None,
            batch_size=batch_size,
        )
예제 #6
0
    def __init__(self, root_folder, mini_batch, train=True):
        """Initializes the class.

        Args:
            root_folder (str): The path of a root directory containing the train
                and validation subdirectories (names given by ``dstUt.DATA_DIR``).
            mini_batch (int): Size of a mini batch.
            train (bool): If True, the object is used for training (random
                augmentation). Otherwise for validation (deterministic resize/crop).
        """
        self.root_folder = root_folder    # Root folder for the train and validation sets
        self.mb_size = mini_batch

        # ImageNet channel mean/std, shared by both pipelines.
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        if train:
            self.path = os.path.join(self.root_folder, dstUt.DATA_DIR['t'])
            self.transforms = transforms.Compose([
                transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
                transforms.RandomRotation(degrees=10),
                transforms.RandomHorizontalFlip(),
                transforms.CenterCrop(size=224),  # ImageNet standards
                transforms.ToTensor(),
                normalize,
            ])
        else:
            self.path = os.path.join(self.root_folder, dstUt.DATA_DIR['v'])
            self.transforms = transforms.Compose([
                transforms.Resize(size=256),
                transforms.CenterCrop(size=224),
                transforms.ToTensor(),
                normalize,
            ])

        # Dataset setup is identical for both splits.
        self.dataset = datasets.ImageFolder(self.path, self.transforms)
        self.class_names = self.dataset.classes
        self.number_classes = len(self.class_names)
        self.data_size = len(self.dataset)

        sampler = None
        if train and distributed_is_initialized():
            sampler = data.DistributedSampler(self.dataset)

        # DataLoader forbids shuffle=True together with an explicit sampler.
        self.loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.mb_size,
                                                  shuffle=(sampler is None), sampler=sampler)
예제 #7
0
 def get_data_loader(self, examples, args):
     """Build a DataLoader over pairs of examples.

     Each item of *examples* is indexed as ``x[0]`` / ``x[1]``. Both members
     are converted to BERT features. The label and guid come from the first
     member.

     Args:
         examples: Re-iterable sequence of 2-element example pairs.
         args: Namespace providing ``max_seq_length``, ``local_rank`` and
             ``train_batch_size``.

     Returns:
         A ``td.DataLoader`` using ``td.RandomSampler`` when
         ``args.local_rank == -1`` and ``td.DistributedSampler`` otherwise.
     """
     firsts = [pair[0] for pair in examples]
     feats_first = bert.convert_examples_to_features(
         firsts, self.get_labels(), args.max_seq_length, self.tokenizer)
     seconds = [pair[1] for pair in examples]
     feats_second = bert.convert_examples_to_features(
         seconds, self.get_labels(), args.max_seq_length, self.tokenizer)
     paired = list(zip(feats_first, feats_second))

     def long_column(side, attr):
         # One int64 tensor holding `attr` of member `side` for every pair.
         return torch.tensor([getattr(p[side], attr) for p in paired],
                             dtype=torch.long)

     columns = [long_column(side, attr)
                for side in (0, 1)
                for attr in ('input_ids', 'input_mask', 'segment_ids')]
     columns.append(long_column(0, 'label_id'))
     guids = [pair[0].guid for pair in examples]
     tensors = td.TensorDataset(*columns)
     dataset = ARCTDataset(guids, tensors)

     if args.local_rank == -1:
         sampler = td.RandomSampler(dataset)
     else:
         sampler = td.DistributedSampler(dataset)
     return td.DataLoader(dataset=dataset,
                          sampler=sampler,
                          batch_size=args.train_batch_size,
                          collate_fn=collate)
예제 #8
0
파일: train.py 프로젝트: lvcat/LEDNet
    def __init__(self, args):
        """Set up data loaders, LEDNet model, loss, optimizer and LR schedule.

        Args:
            args: Parsed command-line namespace. ``args.per_iter`` and
                ``args.max_iter`` are computed here and written back onto it.

        Raises:
            RuntimeError: If ``args.resume`` is set but is not an existing file.
        """
        self.device = torch.device(args.device)
        # image transform: tensor conversion + ImageNet mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        # Iteration-based schedule: per_iter batches per epoch across all GPUs.
        args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        # drop_last=True: every training batch has exactly batch_size samples.
        train_sampler = data.sampler.BatchSampler(sampler, args.batch_size,
                                                  True)
        # Presumably re-iterates the batch sampler until max_iter batches are
        # produced (IterationBasedBatchSampler is defined elsewhere; confirm).
        train_sampler = IterationBasedBatchSampler(
            train_sampler, num_iterations=args.max_iter)
        self.train_loader = data.DataLoader(trainset,
                                            batch_sampler=train_sampler,
                                            pin_memory=True,
                                            num_workers=args.workers)
        # self.valid_loader is only created when evaluation can happen.
        if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
            valset = get_segmentation_dataset(args.dataset,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = data.sampler.BatchSampler(
                val_sampler, args.test_batch_size, False)
            self.valid_loader = data.DataLoader(
                valset,
                batch_sampler=val_batch_sampler,
                num_workers=args.workers,
                pin_memory=True)

        # create network
        self.net = LEDNet(trainset.NUM_CLASS)

        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        self.net.to(self.device)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                # map_location: load straight onto this process's device instead
                # of the device the checkpoint was saved from (matters per rank).
                self.net.load_state_dict(
                    torch.load(args.resume, map_location=self.device))
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

        # create criterion
        if args.ohem:
            # NOTE(review): min_kept scales with the full batch_size, not the
            # per-GPU batch -- confirm this is intended.
            min_kept = args.batch_size * args.crop_size**2 // 16
            self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                         min_kept=min_kept,
                                                         use_weight=False)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss()

        # optimizer and lr scheduling
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
        self.scheduler = WarmupPolyLR(self.optimizer,
                                      T_max=args.max_iter,
                                      warmup_factor=args.warmup_factor,
                                      warmup_iters=args.warmup_iters,
                                      power=0.9)

        # DDP wrapping comes last, after the optimizer holds the parameters.
        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics (same class count as the network above)
        self.metric = SegmentationMetric(trainset.NUM_CLASS)
        self.args = args
예제 #9
0
    def __init__(self, args):
        """Set up the YOLOv3 network, data loaders, optimizer and LR schedule.

        Args:
            args: Parsed command-line namespace. ``args.per_iter`` and
                ``args.max_iter`` are computed here and written back onto it.

        Raises:
            ValueError: If ``args.lr_mode`` is neither ``'cos'`` nor ``'step'``.
        """
        self.device = torch.device(args.device)
        # network
        net_name = '_'.join(('yolo3', args.network, args.dataset))
        self.save_prefix = net_name
        self.net = get_model(net_name, pretrained_base=True)
        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        resume_path = args.resume.strip()
        if resume_path:
            logger.info("Resume from the model {}".format(args.resume))
            # The net has not been moved to self.device yet, so load onto CPU;
            # otherwise every rank would allocate on the checkpoint's saving GPU.
            self.net.load_state_dict(torch.load(resume_path, map_location='cpu'))
        else:
            logger.info("Init from base net {}".format(args.network))
        classes, anchors = self.net.num_class, self.net.anchors
        self.net.set_nms(nms_thresh=0.45, nms_topk=400)
        if args.label_smooth:
            self.net._target_generator._label_smooth = True
        self.net.to(self.device)
        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net, device_ids=[args.local_rank], output_device=args.local_rank)

        # dataset and dataloader
        train_dataset = get_train_data(args.dataset, args.mixup)
        width, height = args.data_shape, args.data_shape
        # Six stacked fields followed by one field padded along axis 0 with -1.
        batchify_fn = Tuple(*([Stack() for _ in range(6)] + [Pad(axis=0, pad_val=-1) for _ in range(1)]))
        train_dataset = train_dataset.transform(
            YOLO3DefaultTrainTransform(width, height, classes, anchors, mixup=args.mixup))
        # Iteration-based schedule: per_iter batches per epoch across all GPUs.
        args.per_iter = len(train_dataset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(train_dataset)
        else:
            sampler = data.RandomSampler(train_dataset)
        train_sampler = data.sampler.BatchSampler(sampler=sampler, batch_size=args.batch_size,
                                                  drop_last=False)
        train_sampler = IterationBasedBatchSampler(train_sampler, num_iterations=args.max_iter)
        if args.no_random_shape:
            self.train_loader = data.DataLoader(train_dataset, batch_sampler=train_sampler, pin_memory=True,
                                                collate_fn=batchify_fn, num_workers=args.num_workers)
        else:
            # Multi-scale training: input sizes 320..608 in steps of 32.
            transform_fns = [YOLO3DefaultTrainTransform(x * 32, x * 32, classes, anchors, mixup=args.mixup)
                             for x in range(10, 20)]
            self.train_loader = RandomTransformDataLoader(transform_fns, train_dataset, batch_sampler=train_sampler,
                                                          collate_fn=batchify_fn, num_workers=args.num_workers)
        # self.val_loader / self.metric are only created when evaluation is enabled.
        if args.eval_epoch > 0:
            # TODO: rewrite it
            val_dataset, self.metric = get_test_data(args.dataset)
            val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
            val_dataset = val_dataset.transform(YOLO3DefaultValTransform(width, height))
            val_sampler = make_data_sampler(val_dataset, False, args.distributed)
            val_batch_sampler = data.BatchSampler(val_sampler, args.test_batch_size, False)
            self.val_loader = data.DataLoader(val_dataset, batch_sampler=val_batch_sampler,
                                              collate_fn=val_batchify_fn, num_workers=args.num_workers)

        # optimizer and lr scheduling
        self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr, momentum=args.momentum,
                                   weight_decay=args.wd)
        if args.lr_mode == 'cos':
            self.scheduler = WarmupCosineLR(optimizer=self.optimizer, T_max=args.max_iter,
                                            warmup_factor=args.warmup_factor, warmup_iters=args.warmup_iters)
        elif args.lr_mode == 'step':
            lr_decay = float(args.lr_decay)
            # Decay epochs (comma-separated) converted to iteration milestones.
            milestones = sorted([float(ls) * args.per_iter for ls in args.lr_decay_epoch.split(',') if ls.strip()])
            self.scheduler = WarmupMultiStepLR(optimizer=self.optimizer, milestones=milestones, gamma=lr_decay,
                                               warmup_factor=args.warmup_factor, warmup_iters=args.warmup_iters)
        else:
            raise ValueError('illegal scheduler type')
        self.args = args
예제 #10
0
파일: base.py 프로젝트: JarvisLL/ACE
    def __init__(self, config,args,logger):
        """Set up environment, data loaders, segmentation model, loss and optimizer.

        Args:
            config: Experiment configuration (DATASET / TRAIN / TEST / MODEL sections).
            args: Command-line namespace; ``args.local_rank`` is used for DDP.
            logger: Logger passed to the resume and save helpers.

        Raises:
            AssertionError: If ``config.TRAIN.SEG_LOSS`` is not a supported loss.
        """
        self.DISTRIBUTED,self.DEVICE = ptutil.init_environment(config,args)
        # Linear LR scaling with the number of GPUs.
        self.LR = config.TRAIN.LR * len(config.GPUS)
        self.device = torch.device(self.DEVICE)
        # image transform: tensor conversion + ImageNet mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader; the image transform is optional per config.
        data_kwargs = {
            'transform': input_transform if config.DATASET.IMG_TRANSFORM else None,
            'base_size': config.DATASET.BASE_SIZE,
            'crop_size': config.DATASET.CROP_SIZE,
        }
        trainset = get_segmentation_dataset(
            config.DATASET.NAME, split=config.TRAIN.TRAIN_SPLIT, mode='train', **data_kwargs)
        # Iteration-based schedule: per_iter batches per epoch across all GPUs.
        self.per_iter = len(trainset) // (len(config.GPUS) * config.TRAIN.BATCH_SIZE)
        self.max_iter = config.TRAIN.EPOCHS * self.per_iter
        if self.DISTRIBUTED:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        train_sampler = data.sampler.BatchSampler(sampler, config.TRAIN.BATCH_SIZE, True)
        train_sampler = IterationBasedBatchSampler(train_sampler, num_iterations=self.max_iter)
        self.train_loader = data.DataLoader(trainset, batch_sampler=train_sampler, pin_memory=config.DATASET.PIN_MEMORY,
                                            num_workers=config.DATASET.WORKERS)
        # self.valid_loader is only created when evaluation can happen.
        if not config.TRAIN.SKIP_EVAL or 0 < config.TRAIN.EVAL_EPOCHS < config.TRAIN.EPOCHS:
            valset = get_segmentation_dataset(config.DATASET.NAME, split='val', mode='val', **data_kwargs)
            val_sampler = make_data_sampler(valset, False, self.DISTRIBUTED)
            val_batch_sampler = data.sampler.BatchSampler(val_sampler, config.TEST.TEST_BATCH_SIZE, False)
            self.valid_loader = data.DataLoader(valset, batch_sampler=val_batch_sampler,
                                                num_workers=config.DATASET.WORKERS, pin_memory=config.DATASET.PIN_MEMORY)
        # create network
        self.net = get_segmentation_model(config.MODEL.NAME,nclass=trainset.NUM_CLASS).cuda()
        if self.DISTRIBUTED:
            # apex SyncBN when using mixed precision, native SyncBN otherwise.
            if config.TRAIN.MIXED_PRECISION:
                self.net = apex.parallel.convert_syncbn_model(self.net)
            else:
                self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        if config.TRAIN.RESUME != '':
            self.net = ptutil.model_resume(self.net,config.TRAIN.RESUME,logger).to(self.device)
        # self.net.to(self.device)
        assert config.TRAIN.SEG_LOSS in ('focalloss2d', 'mixsoftmaxcrossentropyohemloss', 'mixsoftmaxcrossentropy'), 'cannot support {}'.format(config.TRAIN.SEG_LOSS)
        if config.TRAIN.SEG_LOSS == 'focalloss2d':
            self.criterion = get_loss(config.TRAIN.SEG_LOSS,gamma=2., use_weight=False, size_average=True, ignore_index=config.DATASET.IGNORE_INDEX)
        elif config.TRAIN.SEG_LOSS == 'mixsoftmaxcrossentropyohemloss':
            # Minimum kept pixels per GPU: 1/16 of the per-GPU batch's pixels.
            min_kept = int(config.TRAIN.BATCH_SIZE // len(config.GPUS) * config.DATASET.CROP_SIZE ** 2 // 16)
            self.criterion = get_loss(config.TRAIN.SEG_LOSS,min_kept=min_kept,ignore_index =config.DATASET.IGNORE_INDEX).to(self.device)
        else:
            self.criterion = get_loss(config.TRAIN.SEG_LOSS,ignore_index=config.DATASET.IGNORE_INDEX)

        self.optimizer = optim.SGD(self.net.parameters(), lr=self.LR, momentum=config.TRAIN.MOMENTUM,
                                   weight_decay=config.TRAIN.WEIGHT_DECAY)
        self.scheduler = WarmupPolyLR(self.optimizer, T_max=self.max_iter, warmup_factor=config.TRAIN.WARMUP_FACTOR,
                                      warmup_iters=config.TRAIN.WARMUP_ITERS, power=0.9)
        # self.net.apply(fix_bn)
        # amp.initialize must see the bare model and optimizer, before DDP wrapping.
        if config.TRAIN.MIXED_PRECISION:
            self.dtype = torch.half
            self.net,self.optimizer = amp.initialize(self.net,self.optimizer,opt_level=config.TRAIN.MIXED_OPT_LEVEL)
        else:
            self.dtype = torch.float
        if self.DISTRIBUTED:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net, device_ids=[args.local_rank], output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.NUM_CLASS)
        self.config = config
        self.logger = logger
        ptutil.mkdir(self.config.TRAIN.SAVE_DIR)
        # Save the freshly initialized weights as "<model>_<loss>_<dataset>_init.pth".
        model_path = os.path.join(self.config.TRAIN.SAVE_DIR,"{}_{}_{}_init.pth"
                                  .format(config.MODEL.NAME,  config.TRAIN.SEG_LOSS, config.DATASET.NAME))
        ptutil.save_model(self.net,model_path,self.logger)
예제 #11
0
    def __init__(self, args):
        """Set up data loaders, segmentation model, loss, optimizer and LR schedule.

        Args:
            args: Parsed command-line namespace. ``args.per_iter`` and
                ``args.max_iter`` are computed here and written back onto it.

        Raises:
            RuntimeError: If ``args.resume`` is set but is not an existing file.
        """
        self.device = torch.device(args.device)
        self.save_prefix = '_'.join((args.model, args.backbone, args.dataset))
        # image transform: tensor conversion + ImageNet mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        # Iteration-based schedule: per_iter batches per epoch across all GPUs.
        args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
        args.max_iter = args.epochs * args.per_iter
        if args.distributed:
            sampler = data.DistributedSampler(trainset)
        else:
            sampler = data.RandomSampler(trainset)
        train_sampler = data.sampler.BatchSampler(sampler, args.batch_size,
                                                  True)
        train_sampler = IterationBasedBatchSampler(
            train_sampler, num_iterations=args.max_iter)
        self.train_loader = data.DataLoader(trainset,
                                            batch_sampler=train_sampler,
                                            pin_memory=True,
                                            num_workers=args.workers)
        # self.valid_loader is only created when evaluation can happen.
        if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
            valset = get_segmentation_dataset(args.dataset,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = data.sampler.BatchSampler(
                val_sampler, args.test_batch_size, False)
            self.valid_loader = data.DataLoader(
                valset,
                batch_sampler=val_batch_sampler,
                num_workers=args.workers,
                pin_memory=True)

        # create network
        if args.model_zoo is not None:
            self.net = get_model(args.model_zoo, pretrained=True)
        else:
            kwargs = {'oc': args.oc} if args.model == 'ocnet' else {}
            self.net = get_segmentation_model(model=args.model,
                                              dataset=args.dataset,
                                              backbone=args.backbone,
                                              aux=args.aux,
                                              dilated=args.dilated,
                                              jpu=args.jpu,
                                              crop_size=args.crop_size,
                                              **kwargs)
        if args.distributed:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
        self.net.to(self.device)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                # map_location: load straight onto this process's device instead
                # of the device the checkpoint was saved from (matters per rank).
                self.net.load_state_dict(
                    torch.load(args.resume, map_location=self.device))
            else:
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

        # create criterion
        if args.ohem:
            min_kept = args.batch_size * args.crop_size**2 // 16
            self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                         min_kept=min_kept,
                                                         use_weight=False)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss(
                args.aux, aux_weight=args.aux_weight)

        # optimizer and lr scheduling: backbone stages use the base LR, while
        # the extra heads ('others' and JPU, when present) use 10x the base LR.
        params_list = [{'params': getattr(self.net, stage).parameters(), 'lr': args.lr}
                       for stage in ('base1', 'base2', 'base3')]
        if hasattr(self.net, 'others'):
            params_list.extend({'params': getattr(self.net, name).parameters(),
                                'lr': args.lr * 10}
                               for name in self.net.others)
        if hasattr(self.net, 'JPU'):
            params_list.append({
                'params': self.net.JPU.parameters(),
                'lr': args.lr * 10
            })
        self.optimizer = optim.SGD(params_list,
                                   lr=args.lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
        self.scheduler = WarmupPolyLR(self.optimizer,
                                      T_max=args.max_iter,
                                      warmup_factor=args.warmup_factor,
                                      warmup_iters=args.warmup_iters,
                                      power=0.9)

        # DDP wrapping comes last, after parameter groups were collected from
        # the bare module (the attributes above are not reachable through DDP).
        if args.distributed:
            self.net = torch.nn.parallel.DistributedDataParallel(
                self.net,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)
        self.args = args
예제 #12
0
    def __init__(self, config):
        """Prepare scalers and train/validation DataLoaders for particle data.

        Optionally downloads and unpacks the calibration sample, splits the
        particle table into train/validation parts (random_state=42), fits the
        scalers on the train part only, scales both parts and builds the loaders.

        Args:
            config: Experiment configuration. NOTE: ``config.data.data_path`` is
                mutated in place to point at the ``data-calibsample`` directory.

        Raises:
            subprocess.CalledProcessError: If downloading or unpacking fails.
            FileNotFoundError: If ``wget`` or ``tar`` is not installed.
        """
        import subprocess  # only needed for the optional download step

        self.config = config
        self.scalers = {
            'data': get_scaler(config.data.scaler.n_quantiles),
            'context': get_scaler(config.data.scaler.n_quantiles),
            'weight': get_scaler(config.experiment.weights.n_quantiles),
        }

        if self.config.experiment.weights.positive:
            self.scalers['weight'] = MinMaxScaler()
        if not self.config.experiment.weights.enable:
            self.scalers['weight'] = NoneProcessor()

        if config.data.download:
            os.makedirs(config.data.data_path, exist_ok=True)
            log.info('config.data.download is True, starting download')
            target_path = os.path.join(config.data.data_path, 'data-calibsample')
            if os.path.exists(target_path):
                print("It seems that data is already downloaded. Are you sure?")
            archive_path = target_path + '.tar.gz'
            # Argument lists (no shell) keep paths with spaces/metacharacters safe;
            # check=True fails here instead of later on missing files.
            subprocess.run(
                ["wget", "https://cernbox.cern.ch/index.php/s/Fjf3UNgvlRVa4Td/download",
                 "-O", archive_path],
                check=True)
            log.info('files downloaded, starting unpacking')
            # Extract into data_path, where data-calibsample is read from below.
            # Without -C, tar extracts into the current working directory.
            subprocess.run(["tar", "xvf", archive_path, "-C", config.data.data_path],
                           check=True)
            log.info('files unpacked')

        # todo rethink
        config.data.data_path = os.path.join(config.data.data_path, 'data-calibsample')

        table = np.array(get_particle_table(config.data.data_path, config.experiment.particle))
        train_table, val_table = train_test_split(table, test_size=self.config.data.val_size, random_state=42)

        data_dim = config.experiment.data.data_dim
        context_end = data_dim + config.experiment.data.context_dim

        # todo assert weight on last col, mb add to config
        def split_columns(table_):
            # (data columns, context columns, weight column as an (n, 1) array)
            return (table_[:, :data_dim],
                    table_[:, data_dim:context_end],
                    table_[:, -1].reshape(-1, 1))

        def scale(table_):
            # Apply the fitted scalers group-wise and re-join the columns.
            data_cols, context_cols, weight_col = split_columns(table_)
            return np.concatenate([
                self.scalers['data'].transform(data_cols),
                self.scalers['context'].transform(context_cols),
                self.scalers['weight'].transform(weight_col),
            ], axis=1)

        # Fit on the train split only so no validation statistics leak in.
        train_data, train_context, train_weight = split_columns(train_table)
        self.scalers['data'].fit(train_data)
        self.scalers['context'].fit(train_context)
        self.scalers['weight'].fit(train_weight)

        train_table = scale(train_table)
        val_table = scale(val_table)

        train_dataset = ParticleDataset(config, train_table)
        self.train_loader = data.DataLoader(
            dataset=train_dataset,
            batch_size=config.experiment.batch_size,
            # Under DDP the DistributedSampler handles shuffling, and DataLoader
            # forbids shuffle=True together with a sampler.
            sampler=data.DistributedSampler(train_dataset) if config.utils.use_ddp else None,
            shuffle=True if not config.utils.use_ddp else None,
            pin_memory=True,
            drop_last=True
        )
        val_dataset = ParticleDataset(config, val_table)
        self.val_loader = data.DataLoader(
            dataset=val_dataset,
            batch_size=config.experiment.batch_size,
            sampler=None,
            shuffle=False,
            drop_last=True
        )
예제 #13
0
    def __init__(self, config, args, logger):
        """Build data pipelines, networks, losses and optimizers for training.

        Sets up, in order: the (possibly distributed) runtime environment,
        the source-domain training loader, the target-domain loader, an
        optional validation loader, the segmentation network, a VGG19
        feature extractor and a target generator network, optional SyncBN
        conversion, optional checkpoint resumption, the segmentation /
        generator (MSE) / KL criteria, SGD optimizers with ``WarmupPolyLR``
        schedules, optional apex AMP, DDP wrapping in distributed mode, the
        segmentation metric and the checkpoint directories.

        Args:
            config: Experiment config node, read via attribute access
                (``config.TRAIN.*``, ``config.DATASET.*``, ``config.MODEL.*``,
                ``config.TEST.*``, ``config.GPUS``).
            args: Command-line namespace; ``args.local_rank`` is used as the
                DDP device / output device when training is distributed.
            logger: Logger used to report checkpoint loading; stored as
                ``self.logger``.
        """
        self.DISTRIBUTED, self.DEVICE = ptutil.init_environment(config, args)
        # Linear LR scaling: base learning rates are multiplied by the GPU count.
        self.LR = config.TRAIN.LR * len(config.GPUS)
        self.GENERATOR_LR = config.TRAIN.GENERATOR_LR * len(config.GPUS)
        self.device = torch.device(self.DEVICE)

        # image transform: tensor conversion + ImageNet mean/std normalization
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # Only "transform" depends on IMG_TRANSFORM; the sizes are shared.
        data_kwargs = {
            "transform": input_transform if config.DATASET.IMG_TRANSFORM else None,
            "base_size": config.DATASET.BASE_SIZE,
            "crop_size": config.DATASET.CROP_SIZE,
        }

        # target-domain dataset (always its 'train' split)
        targetdataset = get_segmentation_dataset('targetdataset',
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        # source-domain training dataset
        trainset = get_segmentation_dataset(config.DATASET.NAME,
                                            split=config.TRAIN.TRAIN_SPLIT,
                                            mode='train',
                                            **data_kwargs)
        # Iterations per epoch = samples // (num GPUs * BATCH_SIZE); presumably
        # BATCH_SIZE is the per-GPU batch size — confirm against the launcher.
        self.per_iter = len(trainset) // (len(config.GPUS) *
                                          config.TRAIN.BATCH_SIZE)
        targetset_per_iter = len(targetdataset) // (len(config.GPUS) *
                                                    config.TRAIN.BATCH_SIZE)
        targetset_max_iter = config.TRAIN.EPOCHS * targetset_per_iter
        # Was `config.TRAIN.epochs`: every other access uses the upper-case key
        # EPOCHS, and attribute lookup is case-sensitive.
        self.max_iter = config.TRAIN.EPOCHS * self.per_iter

        if self.DISTRIBUTED:
            sampler = data.DistributedSampler(trainset)
            target_sampler = data.DistributedSampler(targetdataset)
        else:
            sampler = data.RandomSampler(trainset)
            target_sampler = data.RandomSampler(targetdataset)
        # Training batches drop the trailing incomplete batch (drop_last=True)
        # and are wrapped by IterationBasedBatchSampler, presumably so the
        # loader yields exactly `num_iterations` batches instead of one epoch.
        train_sampler = data.sampler.BatchSampler(sampler,
                                                  config.TRAIN.BATCH_SIZE,
                                                  True)
        train_sampler = IterationBasedBatchSampler(
            train_sampler, num_iterations=self.max_iter)
        self.train_loader = data.DataLoader(
            trainset,
            batch_sampler=train_sampler,
            pin_memory=config.DATASET.PIN_MEMORY,
            num_workers=config.DATASET.WORKERS)
        target_train_sampler = data.sampler.BatchSampler(
            target_sampler, config.TRAIN.BATCH_SIZE, True)
        target_train_sampler = IterationBasedBatchSampler(
            target_train_sampler, num_iterations=targetset_max_iter)
        self.target_loader = data.DataLoader(
            targetdataset,
            batch_sampler=target_train_sampler,
            pin_memory=False,
            num_workers=config.DATASET.WORKERS)
        # Kept as a live iterator so target batches can be pulled on demand.
        self.target_trainloader_iter = enumerate(self.target_loader)

        # Validation loader is only built when evaluation can happen; note that
        # self.valid_loader is left undefined otherwise.
        if not config.TRAIN.SKIP_EVAL or 0 < config.TRAIN.EVAL_EPOCH < config.TRAIN.EPOCHS:
            valset = get_segmentation_dataset(config.DATASET.NAME,
                                              split='val',
                                              mode='val',
                                              **data_kwargs)
            val_sampler = make_data_sampler(valset, False, self.DISTRIBUTED)
            val_batch_sampler = data.sampler.BatchSampler(
                val_sampler, config.TEST.TEST_BATCH_SIZE, False)
            self.valid_loader = data.DataLoader(
                valset,
                batch_sampler=val_batch_sampler,
                num_workers=config.DATASET.WORKERS,
                pin_memory=False)

        # create network
        # Both trainable networks are placed on self.device up front. Previously
        # the generator was only moved when RESUME_GENERATOR was set, yet it is
        # always handed to amp.initialize / DDP(device_ids=[local_rank]), which
        # expect it on the GPU. self.device is the same target the resume path
        # below already uses for both networks.
        self.seg_net = get_segmentation_model(
            config.MODEL.SEG_NET, nclass=trainset.NUM_CLASS).to(self.device)
        self.feature_extracted = vgg19(pretrained=True)
        self.generator = get_segmentation_model(
            config.MODEL.TARGET_GENERATOR).to(self.device)

        # In distributed mode, BatchNorm layers become synchronized BN
        # (apex's converter under mixed precision, torch's otherwise).
        if self.DISTRIBUTED:
            if config.TRAIN.MIXED_PRECISION:
                self.seg_net = apex.parallel.convert_syncbn_model(self.seg_net)
                self.feature_extracted = apex.parallel.convert_syncbn_model(
                    self.feature_extracted)
                self.generator = apex.parallel.convert_syncbn_model(
                    self.generator)
            else:
                self.seg_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.seg_net)
                self.feature_extracted = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.feature_extracted)
                self.generator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                    self.generator)

        # resume checkpoint if needed
        if config.TRAIN.RESUME != '':
            logger.info('loading {} parameter ...'.format(
                config.MODEL.SEG_NET))
            self.seg_net = ptutil.model_resume(self.seg_net,
                                               config.TRAIN.RESUME,
                                               logger).to(self.device)
        if config.TRAIN.RESUME_GENERATOR != '':
            logger.info('loading {} parameter ...'.format(
                config.MODEL.TARGET_GENERATOR))
            self.generator = ptutil.model_resume(self.generator,
                                                 config.TRAIN.RESUME_GENERATOR,
                                                 logger).to(self.device)

        self.feature_extracted.to(self.device)
        # create criterion
        assert config.TRAIN.SEG_LOSS in (
            'focalloss2d', 'mixsoftmaxcrossentropyohemloss',
            'mixsoftmaxcrossentropy'), 'cannot support {}'.format(
                config.TRAIN.SEG_LOSS)
        # NOTE(review): focal loss uses config.DATASET.IGNORE_INDEX while the
        # other two losses hard-code ignore_index=-1 — confirm these agree.
        if config.TRAIN.SEG_LOSS == 'focalloss2d':
            self.criterion = get_loss(config.TRAIN.SEG_LOSS,
                                      gamma=2.,
                                      use_weight=False,
                                      size_average=True,
                                      ignore_index=config.DATASET.IGNORE_INDEX)
        elif config.TRAIN.SEG_LOSS == 'mixsoftmaxcrossentropyohemloss':
            # OHEM keeps at least 1/16 of the pixels of (BATCH_SIZE // num GPUs)
            # crops. NOTE(review): per_iter above treats BATCH_SIZE as per-GPU,
            # so dividing by the GPU count here may be unintended — confirm.
            min_kept = int(config.TRAIN.BATCH_SIZE // len(config.GPUS) *
                           config.DATASET.CROP_SIZE**2 // 16)
            self.criterion = get_loss(config.TRAIN.SEG_LOSS,
                                      min_kept=min_kept,
                                      ignore_index=-1).to(self.device)
        else:
            self.criterion = get_loss(config.TRAIN.SEG_LOSS, ignore_index=-1)

        self.gen_criterion = get_loss('mseloss')
        self.kl_criterion = get_loss('criterionkldivergence')
        # optimizer and lr scheduling (poly decay with warmup over max_iter steps)
        self.optimizer = optim.SGD(self.seg_net.parameters(),
                                   lr=self.LR,
                                   momentum=config.TRAIN.MOMENTUM,
                                   weight_decay=config.TRAIN.WEIGHT_DECAY)
        self.scheduler = WarmupPolyLR(self.optimizer,
                                      T_max=self.max_iter,
                                      warmup_factor=config.TRAIN.WARMUP_FACTOR,
                                      warmup_iters=config.TRAIN.WARMUP_ITERS,
                                      power=0.9)
        self.gen_optimizer = optim.SGD(self.generator.parameters(),
                                       lr=self.GENERATOR_LR,
                                       momentum=config.TRAIN.MOMENTUM,
                                       weight_decay=config.TRAIN.WEIGHT_DECAY)
        self.gen_scheduler = WarmupPolyLR(
            self.gen_optimizer,
            T_max=self.max_iter,
            warmup_factor=config.TRAIN.WARMUP_FACTOR,
            warmup_iters=config.TRAIN.WARMUP_ITERS,
            power=0.9)

        # Mixed precision: apex AMP wraps both trainable nets and their
        # optimizers; self.dtype records the matching tensor dtype.
        if config.TRAIN.MIXED_PRECISION:
            [self.seg_net, self.generator
             ], [self.optimizer, self.gen_optimizer
                 ] = amp.initialize([self.seg_net, self.generator],
                                    [self.optimizer, self.gen_optimizer],
                                    opt_level=config.TRAIN.MIXED_OPT_LEVEL)
            self.dtype = torch.half
        else:
            self.dtype = torch.float
        # DDP wrapping happens last, after SyncBN conversion, resume and AMP.
        if self.DISTRIBUTED:
            self.seg_net = torch.nn.parallel.DistributedDataParallel(
                self.seg_net,
                device_ids=[args.local_rank],
                output_device=args.local_rank)
            self.generator = torch.nn.parallel.DistributedDataParallel(
                self.generator,
                device_ids=[args.local_rank],
                output_device=args.local_rank)
            self.feature_extracted = torch.nn.parallel.DistributedDataParallel(
                self.feature_extracted,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(trainset.NUM_CLASS)
        self.config = config
        self.logger = logger
        # Checkpoint directories: <SAVE_DIR>/seg and <SAVE_DIR>/generator.
        self.seg_dir = os.path.join(self.config.TRAIN.SAVE_DIR, 'seg')
        ptutil.mkdir(self.seg_dir)
        self.generator_dir = os.path.join(self.config.TRAIN.SAVE_DIR,
                                          'generator')
        ptutil.mkdir(self.generator_dir)