def __init__(self, root, batch_size, train=True):
    # from torchvision import transforms
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    # from torchvision import datasets
    # NOTE: the eval split should use the augmentation-free transform_test
    # (the original passed `transform` unconditionally, leaving it unused).
    dataset = datasets.CIFAR10(root, train=train,
                               transform=transform if train else transform_test,
                               download=True)
    # Distributed: shard the dataset across processes when a process group exists.
    sampler = None
    if train and distributed_is_initialized():
        sampler = data.DistributedSampler(dataset)
    super(CIFAR10DataLoader, self).__init__(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
    )
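# NOTE: `distributed_is_initialized()` is used by this loader and by the
# MNIST/ImageFolder loaders below but is not defined in this section. A
# minimal sketch of what it presumably checks (an assumption, not the
# original helper):
import torch.distributed


def distributed_is_initialized():
    # True only if the distributed package is compiled in AND a process
    # group was created (e.g. via torch.distributed.init_process_group).
    return torch.distributed.is_available() and torch.distributed.is_initialized()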
def __init__(self,
             dataset: dat.Dataset,
             num_workers: int = 4,
             chunk_size: int = 64000,
             batch_size: int = 16,
             distributed: bool = False,
             train: bool = True) -> None:
    self.dataset = dataset
    self.train = train
    self.batch_size = batch_size
    self.splitter = ChunkSplitter(chunk_size, train=train, hop=chunk_size // 2)
    if distributed:
        self.sampler = dat.DistributedSampler(dataset,
                                              shuffle=train,
                                              num_replicas=dist.world_size(),
                                              rank=dist.rank())
    else:
        self.sampler = None
    # just return a batch of egs; supports multiple workers
    # NOTE: batch_size here is NOT the batch size of the audio chunks
    self.eg_loader = dat.DataLoader(self.dataset,
                                    batch_size=min(batch_size, 64),
                                    num_workers=num_workers,
                                    sampler=self.sampler,
                                    shuffle=(train and self.sampler is None),
                                    collate_fn=self._collate)
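# NOTE: `ChunkSplitter` is external to this snippet. A minimal sketch under
# the assumption that it slices each utterance into fixed-size chunks with
# the given hop (chunk_size // 2 above gives 50% overlap); the real class
# may differ.
import numpy as np


class ChunkSplitter(object):
    def __init__(self, chunk_size, train=True, hop=16000):
        self.chunk_size = chunk_size
        self.hop = hop
        self.train = train

    def split(self, wav):
        # slide a window of chunk_size samples over the waveform
        chunks = [wav[beg:beg + self.chunk_size]
                  for beg in range(0, len(wav) - self.chunk_size + 1, self.hop)]
        if not chunks and len(wav):
            # utterance shorter than one chunk: zero-pad it
            pad = np.zeros(self.chunk_size, dtype=wav.dtype)
            pad[:len(wav)] = wav
            chunks.append(pad)
        return chunks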
def create_data_loader(
    train_cfg: dict,
    data_dir: pathlib.Path,
    anchors: torch.Tensor,
    batch_size: int,
    world_size: int,
    val: bool,
    image_size: Tuple[int, int],
    num_classes: int,
    dataset_name: str,
) -> Tuple[data.DataLoader, data.Sampler]:
    """Create the dataloaders for training and evaluation.

    Args:
        train_cfg: The parameters related to the training regime.
        data_dir: The directory where the images are located.
        anchors: The tensor of anchors in the model.
        batch_size: The loader's batch size.
        world_size: Used to decide whether a distributed sampler is needed.
        val: Whether or not this loader is for validation.
        image_size: Size of input images into the model. NOTE: we force square images.
        num_classes: Number of object classes in the dataset.
        dataset_name: Name of the dataset being loaded.

    Returns:
        The dataloader and the loader's sampler. For _training_ we have to set
        the sampler's epoch to reshuffle.
    """
    assert data_dir.is_dir(), data_dir
    meta = data_dir / "annotations.json"
    dataset_ = dataset.DetectionDataset(
        data_dir=data_dir / "images",
        metadata_path=meta,
        img_width=image_size[0],
        img_height=image_size[1],
        validation=val,
    )
    print(dataset_)
    # If using distributed training, use a DistributedSampler so each process
    # loads an exclusive shard of the data.
    sampler = None
    if world_size > 1:
        # NOTE: the original `shuffle=~val` applied integer bitwise NOT to a
        # bool (~False == -1), so it was always truthy; `not val` is intended.
        sampler = data.DistributedSampler(dataset_, shuffle=not val)

    if val:
        collate_fn = collate.CollateVal()
    else:
        collate_fn = collate.Collate(num_classes=num_classes,
                                     original_anchors=anchors)

    loader = data.DataLoader(
        dataset_,
        batch_size=batch_size,
        pin_memory=True,
        sampler=sampler,
        collate_fn=collate_fn,
        num_workers=max(torch.multiprocessing.cpu_count() // world_size, 8),
        drop_last=val,
    )
    return loader, sampler
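# The docstring above notes that for training the sampler's epoch must be set
# to reshuffle. A hypothetical epoch loop (train_cfg, num_epochs, etc. are
# placeholders): without DistributedSampler.set_epoch, each rank replays the
# same shuffled order every epoch.
loader, sampler = create_data_loader(train_cfg, data_dir, anchors,
                                     batch_size, world_size, val=False,
                                     image_size=(512, 512),
                                     num_classes=num_classes,
                                     dataset_name="my-dataset")
for epoch in range(num_epochs):
    if sampler is not None:
        sampler.set_epoch(epoch)
    for batch in loader:
        pass  # forward/backward as usual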
def __init__(self, version, phase, transform_name, factor=0.8,
             category_pool_name=None, label_indent=-1, lmdb=None,
             distributed=False, **kwargs):
    dataset = _MiniImagenetLMDBHorizontal(MINIIMAGENET_DATASET_DIR, version,
                                          phase, factor, category_pool_name,
                                          label_indent, lmdb)
    transform_fn_x, transform_fn_y = dataset_transforms[
        self.__class__.__name__][transform_name]
    collate_fn = functools.partial(default_collate_fn,
                                   transform_fn_x=transform_fn_x,
                                   transform_fn_y=transform_fn_y)
    if distributed:
        if not torch.distributed.is_available():
            raise RuntimeError(
                "Requires the distributed package to be available.")
        if not torch.distributed.is_initialized():
            raise RuntimeError(
                "Requires the distributed process group to be initialized first.")
        sampler = torda.DistributedSampler(dataset)
    else:
        sampler = None
    kwargs['collate_fn'] = collate_fn
    kwargs['sampler'] = sampler
    kwargs['shuffle'] = kwargs.get('shuffle', False) and (sampler is None)
    super().__init__(dataset, **kwargs)
def __init__(self, root, batch_size, train=True):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = datasets.MNIST(root, train=train, transform=transform,
                             download=True)
    sampler = None
    if train and distributed_is_initialized():
        sampler = data.DistributedSampler(dataset)
    super(MNISTDataLoader, self).__init__(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
    )
def __init__(self, root_folder, mini_batch, train=True):
    """
    Initializes the class.

    Args:
        root_folder (str): Path of a root directory containing TrainData and
            ValidationData subdirectories.
        mini_batch (int): Size of a mini batch.
        train (boolean): If True, the object is used for training; otherwise
            for validation.
    """
    self.root_folder = root_folder  # Root folder for the train and validation sets
    self.mb_size = mini_batch
    if train:
        self.path = os.path.join(self.root_folder, dstUt.DATA_DIR['t'])
        self.transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
            transforms.RandomRotation(degrees=10),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(size=224),  # ImageNet standards
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])  # ImageNet statistics
        ])
    else:
        self.path = os.path.join(self.root_folder, dstUt.DATA_DIR['v'])
        self.transforms = transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])  # ImageNet statistics
        ])
    self.dataset = datasets.ImageFolder(self.path, self.transforms)
    self.class_names = self.dataset.classes
    self.number_classes = len(self.class_names)
    self.data_size = len(self.dataset)
    sampler = None
    if train and distributed_is_initialized():
        sampler = data.DistributedSampler(self.dataset)
    self.loader = torch.utils.data.DataLoader(self.dataset,
                                              batch_size=self.mb_size,
                                              shuffle=(sampler is None),
                                              sampler=sampler)
def get_data_loader(self, examples, args):
    features_0 = bert.convert_examples_to_features(
        [x[0] for x in examples], self.get_labels(),
        args.max_seq_length, self.tokenizer)
    features_1 = bert.convert_examples_to_features(
        [x[1] for x in examples], self.get_labels(),
        args.max_seq_length, self.tokenizer)
    features = list(zip(features_0, features_1))
    input_ids_0 = torch.tensor([f[0].input_ids for f in features], dtype=torch.long)
    input_mask_0 = torch.tensor([f[0].input_mask for f in features], dtype=torch.long)
    segment_ids_0 = torch.tensor([f[0].segment_ids for f in features], dtype=torch.long)
    input_ids_1 = torch.tensor([f[1].input_ids for f in features], dtype=torch.long)
    input_mask_1 = torch.tensor([f[1].input_mask for f in features], dtype=torch.long)
    segment_ids_1 = torch.tensor([f[1].segment_ids for f in features], dtype=torch.long)
    label_ids = torch.tensor([f[0].label_id for f in features], dtype=torch.long)
    ids = [x[0].guid for x in examples]
    tensors = td.TensorDataset(
        input_ids_0, input_mask_0, segment_ids_0,
        input_ids_1, input_mask_1, segment_ids_1,
        label_ids)
    train_data = ARCTDataset(ids, tensors)
    if args.local_rank == -1:
        train_sampler = td.RandomSampler(train_data)
    else:
        train_sampler = td.DistributedSampler(train_data)
    data_loader = td.DataLoader(
        dataset=train_data,
        sampler=train_sampler,
        batch_size=args.train_batch_size,
        collate_fn=collate)
    return data_loader
def __init__(self, args):
    self.device = torch.device(args.device)
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # dataset and dataloader
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size
    }
    trainset = get_segmentation_dataset(args.dataset,
                                        split=args.train_split,
                                        mode='train', **data_kwargs)
    args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
    args.max_iter = args.epochs * args.per_iter
    if args.distributed:
        sampler = data.DistributedSampler(trainset)
    else:
        sampler = data.RandomSampler(trainset)
    train_sampler = data.sampler.BatchSampler(sampler, args.batch_size, True)
    train_sampler = IterationBasedBatchSampler(train_sampler,
                                               num_iterations=args.max_iter)
    self.train_loader = data.DataLoader(trainset,
                                        batch_sampler=train_sampler,
                                        pin_memory=True,
                                        num_workers=args.workers)
    if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
        valset = get_segmentation_dataset(args.dataset, split='val',
                                          mode='val', **data_kwargs)
        val_sampler = make_data_sampler(valset, False, args.distributed)
        val_batch_sampler = data.sampler.BatchSampler(
            val_sampler, args.test_batch_size, False)
        self.valid_loader = data.DataLoader(valset,
                                            batch_sampler=val_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

    # create network
    self.net = LEDNet(trainset.NUM_CLASS)
    if args.distributed:
        self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
    self.net.to(self.device)

    # resume checkpoint if needed
    if args.resume is not None:
        if os.path.isfile(args.resume):
            self.net.load_state_dict(torch.load(args.resume))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))

    # create criterion
    if args.ohem:
        min_kept = args.batch_size * args.crop_size ** 2 // 16
        self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                     min_kept=min_kept,
                                                     use_weight=False)
    else:
        self.criterion = MixSoftmaxCrossEntropyLoss()

    # optimizer and lr scheduling
    self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay)
    self.scheduler = WarmupPolyLR(self.optimizer, T_max=args.max_iter,
                                  warmup_factor=args.warmup_factor,
                                  warmup_iters=args.warmup_iters,
                                  power=0.9)
    if args.distributed:
        self.net = torch.nn.parallel.DistributedDataParallel(
            self.net, device_ids=[args.local_rank],
            output_device=args.local_rank)

    # evaluation metrics
    self.metric = SegmentationMetric(trainset.num_class)
    self.args = args
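# NOTE: `IterationBasedBatchSampler` is used by the trainers in this section
# but not defined here. A sketch in the spirit of maskrcnn-benchmark's
# implementation (an assumption about the original):
from torch.utils.data.sampler import BatchSampler


class IterationBasedBatchSampler(BatchSampler):
    """Resample the wrapped BatchSampler until num_iterations batches have
    been yielded, so training length is bounded by iterations, not epochs."""

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration < self.num_iterations:
            # reshuffle distributed shards between passes, if supported
            if hasattr(self.batch_sampler.sampler, 'set_epoch'):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations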
def __init__(self, args):
    self.device = torch.device(args.device)

    # network
    net_name = '_'.join(('yolo3', args.network, args.dataset))
    self.save_prefix = net_name
    self.net = get_model(net_name, pretrained_base=True)
    if args.distributed:
        self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
    if args.resume.strip():
        logger.info("Resume from the model {}".format(args.resume))
        self.net.load_state_dict(torch.load(args.resume.strip()))
    else:
        logger.info("Init from base net {}".format(args.network))
    classes, anchors = self.net.num_class, self.net.anchors
    self.net.set_nms(nms_thresh=0.45, nms_topk=400)
    if args.label_smooth:
        self.net._target_generator._label_smooth = True
    self.net.to(self.device)
    if args.distributed:
        self.net = torch.nn.parallel.DistributedDataParallel(
            self.net, device_ids=[args.local_rank],
            output_device=args.local_rank)

    # dataset and dataloader
    train_dataset = get_train_data(args.dataset, args.mixup)
    width, height = args.data_shape, args.data_shape
    batchify_fn = Tuple(*([Stack() for _ in range(6)] +
                          [Pad(axis=0, pad_val=-1) for _ in range(1)]))
    train_dataset = train_dataset.transform(
        YOLO3DefaultTrainTransform(width, height, classes, anchors,
                                   mixup=args.mixup))
    args.per_iter = len(train_dataset) // (args.num_gpus * args.batch_size)
    args.max_iter = args.epochs * args.per_iter
    if args.distributed:
        sampler = data.DistributedSampler(train_dataset)
    else:
        sampler = data.RandomSampler(train_dataset)
    train_sampler = data.sampler.BatchSampler(sampler=sampler,
                                              batch_size=args.batch_size,
                                              drop_last=False)
    train_sampler = IterationBasedBatchSampler(train_sampler,
                                               num_iterations=args.max_iter)
    if args.no_random_shape:
        self.train_loader = data.DataLoader(train_dataset,
                                            batch_sampler=train_sampler,
                                            pin_memory=True,
                                            collate_fn=batchify_fn,
                                            num_workers=args.num_workers)
    else:
        transform_fns = [YOLO3DefaultTrainTransform(x * 32, x * 32, classes,
                                                    anchors, mixup=args.mixup)
                         for x in range(10, 20)]
        self.train_loader = RandomTransformDataLoader(transform_fns,
                                                      train_dataset,
                                                      batch_sampler=train_sampler,
                                                      collate_fn=batchify_fn,
                                                      num_workers=args.num_workers)

    if args.eval_epoch > 0:
        # TODO: rewrite it
        val_dataset, self.metric = get_test_data(args.dataset)
        val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
        val_dataset = val_dataset.transform(
            YOLO3DefaultValTransform(width, height))
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = data.BatchSampler(val_sampler,
                                              args.test_batch_size, False)
        self.val_loader = data.DataLoader(val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          collate_fn=val_batchify_fn,
                                          num_workers=args.num_workers)

    # optimizer and lr scheduling
    self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr,
                               momentum=args.momentum, weight_decay=args.wd)
    if args.lr_mode == 'cos':
        self.scheduler = WarmupCosineLR(optimizer=self.optimizer,
                                        T_max=args.max_iter,
                                        warmup_factor=args.warmup_factor,
                                        warmup_iters=args.warmup_iters)
    elif args.lr_mode == 'step':
        lr_decay = float(args.lr_decay)
        milestones = sorted([float(ls) * args.per_iter
                             for ls in args.lr_decay_epoch.split(',')
                             if ls.strip()])
        self.scheduler = WarmupMultiStepLR(optimizer=self.optimizer,
                                           milestones=milestones,
                                           gamma=lr_decay,
                                           warmup_factor=args.warmup_factor,
                                           warmup_iters=args.warmup_iters)
    else:
        raise ValueError('illegal scheduler type')
    self.args = args
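# NOTE: `make_data_sampler`, used for the validation loaders throughout this
# section, is also not defined here; presumably something like the following
# sketch:
from torch.utils import data


def make_data_sampler(dataset, shuffle, distributed):
    # shard across processes under distributed training, otherwise fall back
    # to the usual random/sequential samplers
    if distributed:
        return data.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return data.RandomSampler(dataset)
    return data.SequentialSampler(dataset)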
def __init__(self, config, args, logger):
    self.DISTRIBUTED, self.DEVICE = ptutil.init_environment(config, args)
    self.LR = config.TRAIN.LR * len(config.GPUS)
    self.device = torch.device(self.DEVICE)

    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    # dataset and dataloader
    if config.DATASET.IMG_TRANSFORM:
        data_kwargs = {'transform': input_transform,
                       'base_size': config.DATASET.BASE_SIZE,
                       'crop_size': config.DATASET.CROP_SIZE}
    else:
        data_kwargs = {'transform': None,
                       'base_size': config.DATASET.BASE_SIZE,
                       'crop_size': config.DATASET.CROP_SIZE}
    trainset = get_segmentation_dataset(config.DATASET.NAME,
                                        split=config.TRAIN.TRAIN_SPLIT,
                                        mode='train', **data_kwargs)
    self.per_iter = len(trainset) // (len(config.GPUS) * config.TRAIN.BATCH_SIZE)
    self.max_iter = config.TRAIN.EPOCHS * self.per_iter
    if self.DISTRIBUTED:
        sampler = data.DistributedSampler(trainset)
    else:
        sampler = data.RandomSampler(trainset)
    train_sampler = data.sampler.BatchSampler(sampler, config.TRAIN.BATCH_SIZE, True)
    train_sampler = IterationBasedBatchSampler(train_sampler,
                                               num_iterations=self.max_iter)
    self.train_loader = data.DataLoader(trainset,
                                        batch_sampler=train_sampler,
                                        pin_memory=config.DATASET.PIN_MEMORY,
                                        num_workers=config.DATASET.WORKERS)
    if not config.TRAIN.SKIP_EVAL or 0 < config.TRAIN.EVAL_EPOCHS < config.TRAIN.EPOCHS:
        valset = get_segmentation_dataset(config.DATASET.NAME, split='val',
                                          mode='val', **data_kwargs)
        val_sampler = make_data_sampler(valset, False, self.DISTRIBUTED)
        val_batch_sampler = data.sampler.BatchSampler(val_sampler,
                                                      config.TEST.TEST_BATCH_SIZE,
                                                      False)
        self.valid_loader = data.DataLoader(valset,
                                            batch_sampler=val_batch_sampler,
                                            num_workers=config.DATASET.WORKERS,
                                            pin_memory=config.DATASET.PIN_MEMORY)

    # create network
    self.net = get_segmentation_model(config.MODEL.NAME,
                                      nclass=trainset.NUM_CLASS).cuda()
    if self.DISTRIBUTED:
        if config.TRAIN.MIXED_PRECISION:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        else:
            self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
    if config.TRAIN.RESUME != '':
        self.net = ptutil.model_resume(self.net, config.TRAIN.RESUME,
                                       logger).to(self.device)
    # self.net.to(self.device)

    # create criterion
    assert config.TRAIN.SEG_LOSS in ('focalloss2d',
                                     'mixsoftmaxcrossentropyohemloss',
                                     'mixsoftmaxcrossentropy'), \
        'cannot support {}'.format(config.TRAIN.SEG_LOSS)
    if config.TRAIN.SEG_LOSS == 'focalloss2d':
        self.criterion = get_loss(config.TRAIN.SEG_LOSS, gamma=2.,
                                  use_weight=False, size_average=True,
                                  ignore_index=config.DATASET.IGNORE_INDEX)
    elif config.TRAIN.SEG_LOSS == 'mixsoftmaxcrossentropyohemloss':
        min_kept = int(config.TRAIN.BATCH_SIZE // len(config.GPUS) *
                       config.DATASET.CROP_SIZE ** 2 // 16)
        self.criterion = get_loss(config.TRAIN.SEG_LOSS, min_kept=min_kept,
                                  ignore_index=config.DATASET.IGNORE_INDEX).to(self.device)
    else:
        self.criterion = get_loss(config.TRAIN.SEG_LOSS,
                                  ignore_index=config.DATASET.IGNORE_INDEX)

    # optimizer and lr scheduling
    self.optimizer = optim.SGD(self.net.parameters(), lr=self.LR,
                               momentum=config.TRAIN.MOMENTUM,
                               weight_decay=config.TRAIN.WEIGHT_DECAY)
    self.scheduler = WarmupPolyLR(self.optimizer, T_max=self.max_iter,
                                  warmup_factor=config.TRAIN.WARMUP_FACTOR,
                                  warmup_iters=config.TRAIN.WARMUP_ITERS,
                                  power=0.9)
    # self.net.apply(fix_bn)

    if config.TRAIN.MIXED_PRECISION:
        self.dtype = torch.half
        self.net, self.optimizer = amp.initialize(self.net, self.optimizer,
                                                  opt_level=config.TRAIN.MIXED_OPT_LEVEL)
    else:
        self.dtype = torch.float
    if self.DISTRIBUTED:
        self.net = torch.nn.parallel.DistributedDataParallel(
            self.net, device_ids=[args.local_rank],
            output_device=args.local_rank)

    # evaluation metrics
    self.metric = SegmentationMetric(trainset.NUM_CLASS)
    self.config = config
    self.logger = logger
    ptutil.mkdir(self.config.TRAIN.SAVE_DIR)
    model_path = os.path.join(self.config.TRAIN.SAVE_DIR,
                              "{}_{}_{}_init.pth".format(config.MODEL.NAME,
                                                         config.TRAIN.SEG_LOSS,
                                                         config.DATASET.NAME))
    ptutil.save_model(self.net, model_path, self.logger)
def __init__(self, args):
    self.device = torch.device(args.device)
    self.save_prefix = '_'.join((args.model, args.backbone, args.dataset))

    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    # dataset and dataloader
    data_kwargs = {
        'transform': input_transform,
        'base_size': args.base_size,
        'crop_size': args.crop_size
    }
    trainset = get_segmentation_dataset(args.dataset,
                                        split=args.train_split,
                                        mode='train', **data_kwargs)
    args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
    args.max_iter = args.epochs * args.per_iter
    if args.distributed:
        sampler = data.DistributedSampler(trainset)
    else:
        sampler = data.RandomSampler(trainset)
    train_sampler = data.sampler.BatchSampler(sampler, args.batch_size, True)
    train_sampler = IterationBasedBatchSampler(train_sampler,
                                               num_iterations=args.max_iter)
    self.train_loader = data.DataLoader(trainset,
                                        batch_sampler=train_sampler,
                                        pin_memory=True,
                                        num_workers=args.workers)
    if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
        valset = get_segmentation_dataset(args.dataset, split='val',
                                          mode='val', **data_kwargs)
        val_sampler = make_data_sampler(valset, False, args.distributed)
        val_batch_sampler = data.sampler.BatchSampler(
            val_sampler, args.test_batch_size, False)
        self.valid_loader = data.DataLoader(valset,
                                            batch_sampler=val_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

    # create network
    if args.model_zoo is not None:
        self.net = get_model(args.model_zoo, pretrained=True)
    else:
        kwargs = {'oc': args.oc} if args.model == 'ocnet' else {}
        self.net = get_segmentation_model(model=args.model,
                                          dataset=args.dataset,
                                          backbone=args.backbone,
                                          aux=args.aux,
                                          dilated=args.dilated,
                                          jpu=args.jpu,
                                          crop_size=args.crop_size,
                                          **kwargs)
    if args.distributed:
        self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
    self.net.to(self.device)

    # resume checkpoint if needed
    if args.resume is not None:
        if os.path.isfile(args.resume):
            self.net.load_state_dict(torch.load(args.resume))
        else:
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))

    # create criterion
    if args.ohem:
        min_kept = args.batch_size * args.crop_size ** 2 // 16
        self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                                     min_kept=min_kept,
                                                     use_weight=False)
    else:
        self.criterion = MixSoftmaxCrossEntropyLoss(args.aux,
                                                    aux_weight=args.aux_weight)

    # optimizer and lr scheduling: backbone stages at the base lr,
    # newly added heads at 10x lr
    params_list = [
        {'params': self.net.base1.parameters(), 'lr': args.lr},
        {'params': self.net.base2.parameters(), 'lr': args.lr},
        {'params': self.net.base3.parameters(), 'lr': args.lr},
    ]
    if hasattr(self.net, 'others'):
        for name in self.net.others:
            params_list.append({'params': getattr(self.net, name).parameters(),
                                'lr': args.lr * 10})
    if hasattr(self.net, 'JPU'):
        params_list.append({'params': self.net.JPU.parameters(),
                            'lr': args.lr * 10})
    self.optimizer = optim.SGD(params_list, lr=args.lr,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay)
    self.scheduler = WarmupPolyLR(self.optimizer, T_max=args.max_iter,
                                  warmup_factor=args.warmup_factor,
                                  warmup_iters=args.warmup_iters,
                                  power=0.9)
    if args.distributed:
        self.net = torch.nn.parallel.DistributedDataParallel(
            self.net, device_ids=[args.local_rank],
            output_device=args.local_rank)

    # evaluation metrics
    self.metric = SegmentationMetric(trainset.num_class)
    self.args = args
def __init__(self, config):
    self.config = config
    self.scalers = {
        'data': get_scaler(config.data.scaler.n_quantiles),
        'context': get_scaler(config.data.scaler.n_quantiles),
        'weight': get_scaler(config.experiment.weights.n_quantiles),
    }
    if self.config.experiment.weights.positive:
        self.scalers['weight'] = MinMaxScaler()
    if not self.config.experiment.weights.enable:
        self.scalers['weight'] = NoneProcessor()

    if config.data.download:
        if not os.path.exists(config.data.data_path):
            os.makedirs(config.data.data_path)
        log.info('config.data.download is True, starting download')
        target_path = os.path.join(config.data.data_path, 'data-calibsample')
        if os.path.exists(target_path):
            print("It seems that the data is already downloaded. Are you sure?")
        os.system(f"wget https://cernbox.cern.ch/index.php/s/Fjf3UNgvlRVa4Td/download -O {target_path + '.tar.gz'}")
        log.info('files downloaded, starting unpacking')
        os.system(f"tar xvf {target_path + '.tar.gz'}")
        log.info('files unpacked')
    # todo rethink
    config.data.data_path = os.path.join(config.data.data_path, 'data-calibsample')

    table = np.array(get_particle_table(config.data.data_path,
                                        config.experiment.particle))
    train_table, val_table = train_test_split(table,
                                              test_size=self.config.data.val_size,
                                              random_state=42)
    data_dim = config.experiment.data.data_dim
    context_dim = config.experiment.data.context_dim
    self.scalers['data'].fit(train_table[:, :data_dim])
    self.scalers['context'].fit(train_table[:, data_dim:data_dim + context_dim])
    # todo: assert the weight is on the last column, maybe add to config
    self.scalers['weight'].fit(train_table[:, -1].reshape(-1, 1))

    train_table = np.concatenate([
        self.scalers['data'].transform(train_table[:, :data_dim]),
        self.scalers['context'].transform(train_table[:, data_dim:data_dim + context_dim]),
        self.scalers['weight'].transform(train_table[:, -1].reshape(-1, 1)),
    ], axis=1)
    val_table = np.concatenate([
        self.scalers['data'].transform(val_table[:, :data_dim]),
        self.scalers['context'].transform(val_table[:, data_dim:data_dim + context_dim]),
        self.scalers['weight'].transform(val_table[:, -1].reshape(-1, 1)),
    ], axis=1)

    train_dataset = ParticleDataset(config, train_table)
    self.train_loader = data.DataLoader(
        dataset=train_dataset,
        batch_size=config.experiment.batch_size,
        sampler=data.DistributedSampler(train_dataset) if config.utils.use_ddp else None,
        shuffle=True if not config.utils.use_ddp else None,
        pin_memory=True,
        drop_last=True,
    )
    val_dataset = ParticleDataset(config, val_table)
    # validation is not sharded: every rank sees the full validation set
    self.val_loader = data.DataLoader(
        dataset=val_dataset,
        batch_size=config.experiment.batch_size,
        sampler=None,
        shuffle=False,
        drop_last=True,
    )
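# NOTE: `NoneProcessor` is referenced above but not defined in this section.
# A hypothetical no-op scaler that keeps the sklearn fit/transform interface,
# so the weight column can be passed through unchanged:
class NoneProcessor:
    def fit(self, x, y=None):
        return self

    def transform(self, x):
        return x

    def inverse_transform(self, x):
        return x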
def __init__(self, config, args, logger):
    self.DISTRIBUTED, self.DEVICE = ptutil.init_environment(config, args)
    self.LR = config.TRAIN.LR * len(config.GPUS)  # scale by num gpus
    self.GENERATOR_LR = config.TRAIN.GENERATOR_LR * len(config.GPUS)
    self.device = torch.device(self.DEVICE)

    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    if config.DATASET.IMG_TRANSFORM:
        data_kwargs = {
            "transform": input_transform,
            "base_size": config.DATASET.BASE_SIZE,
            "crop_size": config.DATASET.CROP_SIZE
        }
    else:
        data_kwargs = {
            "transform": None,
            "base_size": config.DATASET.BASE_SIZE,
            "crop_size": config.DATASET.CROP_SIZE
        }

    # target dataset
    targetdataset = get_segmentation_dataset('targetdataset', split='train',
                                             mode='train', **data_kwargs)
    trainset = get_segmentation_dataset(config.DATASET.NAME,
                                        split=config.TRAIN.TRAIN_SPLIT,
                                        mode='train', **data_kwargs)
    self.per_iter = len(trainset) // (len(config.GPUS) * config.TRAIN.BATCH_SIZE)
    targetset_per_iter = len(targetdataset) // (len(config.GPUS) * config.TRAIN.BATCH_SIZE)
    targetset_max_iter = config.TRAIN.EPOCHS * targetset_per_iter
    # NOTE: the original read `config.TRAIN.epochs`, inconsistent with the
    # `config.TRAIN.EPOCHS` used just above; unified to EPOCHS.
    self.max_iter = config.TRAIN.EPOCHS * self.per_iter
    if self.DISTRIBUTED:
        sampler = data.DistributedSampler(trainset)
        target_sampler = data.DistributedSampler(targetdataset)
    else:
        sampler = data.RandomSampler(trainset)
        target_sampler = data.RandomSampler(targetdataset)
    train_sampler = data.sampler.BatchSampler(sampler, config.TRAIN.BATCH_SIZE, True)
    train_sampler = IterationBasedBatchSampler(train_sampler,
                                               num_iterations=self.max_iter)
    self.train_loader = data.DataLoader(trainset,
                                        batch_sampler=train_sampler,
                                        pin_memory=config.DATASET.PIN_MEMORY,
                                        num_workers=config.DATASET.WORKERS)
    target_train_sampler = data.sampler.BatchSampler(target_sampler,
                                                     config.TRAIN.BATCH_SIZE, True)
    target_train_sampler = IterationBasedBatchSampler(target_train_sampler,
                                                      num_iterations=targetset_max_iter)
    self.target_loader = data.DataLoader(targetdataset,
                                         batch_sampler=target_train_sampler,
                                         pin_memory=False,
                                         num_workers=config.DATASET.WORKERS)
    self.target_trainloader_iter = enumerate(self.target_loader)
    if not config.TRAIN.SKIP_EVAL or 0 < config.TRAIN.EVAL_EPOCH < config.TRAIN.EPOCHS:
        valset = get_segmentation_dataset(config.DATASET.NAME, split='val',
                                          mode='val', **data_kwargs)
        val_sampler = make_data_sampler(valset, False, self.DISTRIBUTED)
        val_batch_sampler = data.sampler.BatchSampler(val_sampler,
                                                      config.TEST.TEST_BATCH_SIZE,
                                                      False)
        self.valid_loader = data.DataLoader(valset,
                                            batch_sampler=val_batch_sampler,
                                            num_workers=config.DATASET.WORKERS,
                                            pin_memory=False)

    # create network
    self.seg_net = get_segmentation_model(config.MODEL.SEG_NET,
                                          nclass=trainset.NUM_CLASS).cuda()
    self.feature_extracted = vgg19(pretrained=True)
    self.generator = get_segmentation_model(config.MODEL.TARGET_GENERATOR)
    if self.DISTRIBUTED:
        if config.TRAIN.MIXED_PRECISION:
            self.seg_net = apex.parallel.convert_syncbn_model(self.seg_net)
            self.feature_extracted = apex.parallel.convert_syncbn_model(self.feature_extracted)
            self.generator = apex.parallel.convert_syncbn_model(self.generator)
        else:
            self.seg_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.seg_net)
            self.feature_extracted = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.feature_extracted)
            self.generator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.generator)

    # resume checkpoint if needed
    if config.TRAIN.RESUME != '':
        logger.info('loading {} parameters ...'.format(config.MODEL.SEG_NET))
        self.seg_net = ptutil.model_resume(self.seg_net, config.TRAIN.RESUME,
                                           logger).to(self.device)
    if config.TRAIN.RESUME_GENERATOR != '':
        logger.info('loading {} parameters ...'.format(config.MODEL.TARGET_GENERATOR))
        self.generator = ptutil.model_resume(self.generator,
                                             config.TRAIN.RESUME_GENERATOR,
                                             logger).to(self.device)
    self.feature_extracted.to(self.device)

    # create criterion
    assert config.TRAIN.SEG_LOSS in ('focalloss2d',
                                     'mixsoftmaxcrossentropyohemloss',
                                     'mixsoftmaxcrossentropy'), \
        'cannot support {}'.format(config.TRAIN.SEG_LOSS)
    if config.TRAIN.SEG_LOSS == 'focalloss2d':
        self.criterion = get_loss(config.TRAIN.SEG_LOSS, gamma=2.,
                                  use_weight=False, size_average=True,
                                  ignore_index=config.DATASET.IGNORE_INDEX)
    elif config.TRAIN.SEG_LOSS == 'mixsoftmaxcrossentropyohemloss':
        min_kept = int(config.TRAIN.BATCH_SIZE // len(config.GPUS) *
                       config.DATASET.CROP_SIZE ** 2 // 16)
        self.criterion = get_loss(config.TRAIN.SEG_LOSS, min_kept=min_kept,
                                  ignore_index=-1).to(self.device)
    else:
        self.criterion = get_loss(config.TRAIN.SEG_LOSS, ignore_index=-1)
    self.gen_criterion = get_loss('mseloss')
    self.kl_criterion = get_loss('criterionkldivergence')

    # optimizer and lr scheduling
    self.optimizer = optim.SGD(self.seg_net.parameters(), lr=self.LR,
                               momentum=config.TRAIN.MOMENTUM,
                               weight_decay=config.TRAIN.WEIGHT_DECAY)
    self.scheduler = WarmupPolyLR(self.optimizer, T_max=self.max_iter,
                                  warmup_factor=config.TRAIN.WARMUP_FACTOR,
                                  warmup_iters=config.TRAIN.WARMUP_ITERS,
                                  power=0.9)
    self.gen_optimizer = optim.SGD(self.generator.parameters(),
                                   lr=self.GENERATOR_LR,
                                   momentum=config.TRAIN.MOMENTUM,
                                   weight_decay=config.TRAIN.WEIGHT_DECAY)
    self.gen_scheduler = WarmupPolyLR(self.gen_optimizer, T_max=self.max_iter,
                                      warmup_factor=config.TRAIN.WARMUP_FACTOR,
                                      warmup_iters=config.TRAIN.WARMUP_ITERS,
                                      power=0.9)

    if config.TRAIN.MIXED_PRECISION:
        [self.seg_net, self.generator], [self.optimizer, self.gen_optimizer] = \
            amp.initialize([self.seg_net, self.generator],
                           [self.optimizer, self.gen_optimizer],
                           opt_level=config.TRAIN.MIXED_OPT_LEVEL)
        self.dtype = torch.half
    else:
        self.dtype = torch.float
    if self.DISTRIBUTED:
        self.seg_net = torch.nn.parallel.DistributedDataParallel(
            self.seg_net, device_ids=[args.local_rank],
            output_device=args.local_rank)
        self.generator = torch.nn.parallel.DistributedDataParallel(
            self.generator, device_ids=[args.local_rank],
            output_device=args.local_rank)
        self.feature_extracted = torch.nn.parallel.DistributedDataParallel(
            self.feature_extracted, device_ids=[args.local_rank],
            output_device=args.local_rank)

    # evaluation metrics
    self.metric = SegmentationMetric(trainset.NUM_CLASS)
    self.config = config
    self.logger = logger
    self.seg_dir = os.path.join(self.config.TRAIN.SAVE_DIR, 'seg')
    ptutil.mkdir(self.seg_dir)
    self.generator_dir = os.path.join(self.config.TRAIN.SAVE_DIR, 'generator')
    ptutil.mkdir(self.generator_dir)