def __call__(self, transform=None, target_transform=None, batch_size=64, num_workers=0):
    trainset = ImageItemList(self.trainlist, self.labels2idxs,
                             transform=transform[0], target_transform=target_transform)
    trainsampler = distributed.DistributedSampler(trainset)
    trainloader = DataLoader(trainset, batch_size=batch_size, pin_memory=True,
                             drop_last=True, num_workers=num_workers, sampler=trainsampler)
    validset = ImageItemList(self.validlist, self.labels2idxs,
                             transform=transform[1], target_transform=target_transform)
    validsampler = distributed.DistributedSampler(validset)
    validloader = DataLoader(validset, batch_size=batch_size, pin_memory=True,
                             drop_last=True, num_workers=num_workers, sampler=validsampler)
    return trainloader, trainsampler, validloader, validsampler
def make_torch_dataloaders(train_dataset, test_dataset, rank, world_size, bs, num_workers=4,
                           distrib=True, sync_valid=False):
    "make torch-based distributed dataloaders from torch compatible datasets"
    if distrib:
        train_sampler = th_distrib.DistributedSampler(train_dataset, num_replicas=world_size,
                                                      rank=rank, shuffle=True)
        train_loader = th_data.DataLoader(train_dataset, batch_size=bs, sampler=train_sampler,
                                          # shuffle=True,
                                          num_workers=num_workers, drop_last=True)
        if sync_valid:
            test_sampler = th_distrib.DistributedSampler(test_dataset, num_replicas=world_size,
                                                         rank=rank, shuffle=False)
            test_loader = th_data.DataLoader(test_dataset, batch_size=bs, sampler=test_sampler,
                                             # shuffle=False,
                                             num_workers=num_workers, drop_last=True)
        else:
            test_loader = th_data.DataLoader(test_dataset, batch_size=bs, shuffle=False,
                                             num_workers=num_workers, drop_last=True)
    else:
        train_loader = th_data.DataLoader(train_dataset, batch_size=bs,
                                          # sampler=train_sampler,
                                          shuffle=True, num_workers=num_workers, drop_last=True)
        test_loader = th_data.DataLoader(test_dataset, batch_size=bs, shuffle=False,
                                         num_workers=num_workers, drop_last=True)
    dataloaders = DataLoaders(train_loader, test_loader, device=None)
    return dataloaders
def create_dataloader(config, data, mode):
    dataset = create_dataset(config, data, mode)
    if mode == 'train':
        # create sampler
        if dist.is_available() and dist.is_initialized():
            train_RandomSampler = distributed.DistributedSampler(dataset)
        else:
            train_RandomSampler = sampler.RandomSampler(dataset, replacement=False)
        train_BatchSampler = sampler.BatchSampler(train_RandomSampler,
                                                  batch_size=config.train.batch_size,
                                                  drop_last=config.train.dataloader.drop_last)
        # augmentation collate function
        collator = get_collate_fn(config)
        # DataLoader
        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=train_BatchSampler,
                                 collate_fn=collator,
                                 pin_memory=config.train.dataloader.pin_memory,
                                 num_workers=config.train.dataloader.work_nums)
    elif mode == 'val':
        if dist.is_available() and dist.is_initialized():
            val_SequentialSampler = distributed.DistributedSampler(dataset)
        else:
            val_SequentialSampler = sampler.SequentialSampler(dataset)
        val_BatchSampler = sampler.BatchSampler(val_SequentialSampler,
                                                batch_size=config.val.batch_size,
                                                drop_last=config.val.dataloader.drop_last)
        data_loader = DataLoader(dataset,
                                 batch_sampler=val_BatchSampler,
                                 pin_memory=config.val.dataloader.pin_memory,
                                 num_workers=config.val.dataloader.work_nums)
    else:
        if dist.is_available() and dist.is_initialized():
            test_SequentialSampler = distributed.DistributedSampler(dataset)
        else:
            test_SequentialSampler = None
        # note: the test loader reuses the val dataloader settings for pin_memory and num_workers
        data_loader = DataLoader(dataset,
                                 sampler=test_SequentialSampler,
                                 batch_size=config.test.batch_size,
                                 pin_memory=config.val.dataloader.pin_memory,
                                 num_workers=config.val.dataloader.work_nums)
    return data_loader
def loaders_and_test_images(gpu, options):
    train_sampler = distutils.DistributedSampler(
        options.dataset, num_replicas=options.gpus,
        rank=gpu) if options.gpus > 1 else None
    train_loader = DataLoader(
        options.dataset,
        batch_size=options.batch_size * 2,  # batch size * 2: get both anchors and negatives
        num_workers=options.num_workers,
        collate_fn=datautils.gp_annotated_collate_fn,
        pin_memory=True,
        shuffle=(options.gpus == 1),
        sampler=train_sampler)

    # not using distributed sampling here as it's just an auxiliary set for the discriminator
    # and the selection of boxes is random anyway; might take it into use later though
    disc_loader = DiscriminatorLoader(options)

    test_images, gen_test_images = zip(*[
        options.evaldata[img_index % len(options.evaldata)][:2]
        for img_index in options.sample_indices
    ])
    test_images = torch.stack(test_images)
    gen_test_images = torch.stack(gen_test_images)

    return train_loader, disc_loader, train_sampler, test_images, gen_test_images
def __init__(self, dataset, batch_size, shuffle, num_workers=0, drop_last=False,
             sampler=None, half=None, dist_=None):
    self.dataset = dataset
    if half == 1:
        self.half = True
    else:
        self.half = False
    # note: any sampler passed in is discarded; a DistributedSampler is created only
    # when dist_ == 1, and shuffling is then delegated to that sampler
    sampler = None
    if dist_ == 1:
        sampler = distributed.DistributedSampler(dataset)
        shuffle = False
    self.dataloader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  pin_memory=True,
                                                  shuffle=shuffle,
                                                  sampler=sampler,
                                                  drop_last=drop_last)
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False,
                      cache=False, pad=0.0, rect=False, rank=-1, workers=8, image_weights=False,
                      quad=False, prefix='', shuffle=False):
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augmentation
                                      hyp=hyp,  # hyperparameters
                                      rect=rect,  # rectangular batches
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // DEVICE_COUNT, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    return loader(dataset,
                  batch_size=batch_size,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=True,
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
def get_train_loader_concat(conf, data_roots, sample_identity=False):
    extensions = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
    total_class_num = 0
    datasets = []
    for root in data_roots:
        class_num, class_to_idx = find_classes(root)
        train_transform = trans.Compose([
            trans.RandomHorizontalFlip(),
            trans.ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0),
            trans.ToTensor(),
            trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        path_ds = make_dataset(root, class_to_idx, extensions)
        # offset the labels of each dataset so classes stay distinct after concatenation
        for i, (url, label) in enumerate(path_ds):
            path_ds[i] = (url, label + total_class_num)
        datasets.extend(path_ds)
        total_class_num += class_num
    # logger.debug('datasets {}'.format(datasets))
    image_ds = ImageDataset(datasets, train_transform)
    if sample_identity:
        train_sampler = DistRandomIdentitySampler(image_ds.dataset, conf.batch_size, conf.num_instances)
    else:
        train_sampler = distributed.DistributedSampler(image_ds)
    loader = DataLoader(image_ds, batch_size=conf.batch_size, shuffle=False,
                        pin_memory=conf.pin_memory, num_workers=conf.num_workers,
                        sampler=train_sampler)
    return loader, total_class_num
def get_inference_dataloader(dataset: Type[Dataset],
                             transforms: Callable,
                             batch_size: int = 16,
                             num_workers: int = 8,
                             pin_memory: bool = True,
                             limit_num_samples: Optional[int] = None) -> DataLoader:
    if limit_num_samples is not None:
        np.random.seed(limit_num_samples)
        indices = np.random.permutation(len(dataset))[:limit_num_samples]
        dataset = Subset(dataset, indices)

    dataset = TransformedDataset(dataset, transform_fn=transforms)

    sampler = None
    if dist.is_available() and dist.is_initialized():
        sampler = data_dist.DistributedSampler(dataset, shuffle=False)

    loader = DataLoader(dataset, shuffle=False, batch_size=batch_size,
                        num_workers=num_workers, sampler=sampler,
                        pin_memory=pin_memory, drop_last=False)
    return loader
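# Not part of the example above: a caveat when using DistributedSampler for inference.
# The sampler pads the per-rank index lists so every rank gets the same number of samples,
# which duplicates a few items whenever len(dataset) is not divisible by the world size.
# Below is a minimal sketch of a gather step that restores the original order and trims the
# padding; `gather_predictions` is a hypothetical helper, only the torch.distributed calls
# are standard API.
import torch
import torch.distributed as dist


def gather_predictions(local_preds: torch.Tensor, dataset_len: int) -> torch.Tensor:
    """Collect per-rank predictions (in sampler order) and drop the padded duplicates."""
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_preds) for _ in range(world_size)]
    dist.all_gather(gathered, local_preds)
    # with shuffle=False, rank r processed dataset indices r, r + world_size, r + 2 * world_size, ...
    # so interleaving the gathered tensors restores the original dataset order
    stacked = torch.stack(gathered, dim=1).reshape(-1, *local_preds.shape[1:])
    return stacked[:dataset_len]  # drop the padded tail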
def get_train_loader(conf, data_mode, sample_identity=False):
    if data_mode == 'emore':
        root = conf.emore_folder/'imgs'
    elif data_mode == 'glint':
        root = conf.glint_folder/'imgs'
    else:
        logger.fatal('invalid data_mode {}'.format(data_mode))
        exit(1)
    class_num, class_to_idx = find_classes(root)
    train_transform = trans.Compose([
        trans.RandomHorizontalFlip(),
        trans.ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0),
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    extensions = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
    path_ds = make_dataset(root, class_to_idx, extensions)
    dataset = ImageDataset(path_ds, train_transform)
    if sample_identity:
        train_sampler = DistRandomIdentitySampler(dataset.dataset, conf.batch_size, conf.num_instances)
    else:
        train_sampler = distributed.DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=conf.batch_size, shuffle=False,
                        pin_memory=conf.pin_memory, num_workers=conf.num_workers,
                        sampler=train_sampler)
    return loader, class_num
def test_distributedsampler():
    # fake dataset
    dataset = list(range(12345))
    num_proc = 8
    rank = 0
    sampler = datadist.DistributedSampler(dataset=dataset,
                                          num_replicas=num_proc,
                                          duplicate_last=True,
                                          shuffle=False)
def __init__(self, args):
    self.loader_train = None
    if not args.test_only:
        datasets = []
        for d in args.data_train:
            module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
            m = import_module('data.' + module_name.lower())
            datasets.append(getattr(m, module_name)(args, name=d))

        # add distributed training dataset sampler according to
        # https://github.com/horovod/horovod/blob/master/examples/pytorch/pytorch_imagenet_resnet50.py
        self.train_sampler = torchDis.DistributedSampler(
            # taking concatenated datasets
            MyConcatDataset(datasets),
            # number of processes participating in distributed training
            num_replicas=hvd.size(),
            # rank of the current process within num_replicas
            rank=hvd.rank())

        self.loader_train = dataloader.DataLoader(
            # taking concatenated datasets
            MyConcatDataset(datasets),
            # how many samples per batch to load
            batch_size=args.batch_size,
            # data shuffling disabled when using a sampler
            shuffle=False,
            pin_memory=not args.cpu,
            # worker processes used for data loading
            num_workers=args.n_threads,
            # the new distributed sampler
            sampler=self.train_sampler)

    self.loader_test = []
    for d in args.data_test:
        if d in ['Set5', 'Set14', 'B100', 'Urban100']:
            m = import_module('data.benchmark')
            testset = getattr(m, 'Benchmark')(args, train=False, name=d)
        else:
            module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
            m = import_module('data.' + module_name.lower())
            testset = getattr(m, module_name)(args, train=False, name=d)

        self.loader_test.append(
            dataloader.DataLoader(
                testset,
                batch_size=1,
                shuffle=False,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            ))
def make_data_sampler(dataset, shuffle, is_distributed):
    if is_distributed:
        # distributed.DistributedSampler shuffles by default
        return distributed.DistributedSampler(dataset)
    if shuffle:
        return torch.utils.data.sampler.RandomSampler(dataset)
    else:
        return torch.utils.data.sampler.SequentialSampler(dataset)
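# Not part of the example above: a minimal sketch of how a sampler built by make_data_sampler
# is typically consumed in a training loop. The toy TensorDataset and batch size are assumptions;
# the important detail is calling set_epoch so DistributedSampler reshuffles every epoch.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# hypothetical toy dataset; any map-style Dataset works the same way
dataset = TensorDataset(torch.arange(1000).float().unsqueeze(1))

# assumes the default process group was already initialized (e.g. by torchrun) when running distributed
is_distributed = dist.is_available() and dist.is_initialized()
sampler = make_data_sampler(dataset, shuffle=True, is_distributed=is_distributed)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(3):
    if isinstance(sampler, DistributedSampler):
        # without set_epoch, every epoch would see the same shuffled order
        sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # forward/backward would go here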
def set_distributed_sampler(self: th_data.DataLoader, rank, world_size):
    'replace sampler with torch distributed sampler'
    # note: recent PyTorch versions forbid reassigning `sampler`/`batch_sampler` on an
    # already-constructed DataLoader, so this in-place patch only works on versions
    # without that guard in DataLoader.__setattr__
    distrib_sampler = th_distrib.DistributedSampler(self.dataset, num_replicas=world_size,
                                                    rank=rank, shuffle=True)
    self.sampler = distrib_sampler
    batch_sampler_klass = self.batch_sampler.__class__
    self.batch_sampler = batch_sampler_klass(self.sampler, self.batch_size, self.drop_last)
def loader_and_test_img(gpu, options):
    test_image, _ = options.dataset[0]
    sampler = distutils.DistributedSampler(
        options.dataset, num_replicas=options.gpus,
        rank=gpu) if options.gpus > 1 else None
    loader = DataLoader(options.dataset,
                        batch_size=options.batch_size,
                        num_workers=options.num_workers,
                        collate_fn=datautils.sku110k_collate_fn,
                        pin_memory=True,
                        shuffle=(options.gpus == 1),
                        sampler=sampler)
    return loader, sampler, test_image
def get_train_val_loader(args):
    train_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    normalize = transforms.Normalize(mean=[(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2],
                                     std=[(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2])
    train_dataset = datasets.ImageFolder(
        train_folder,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.0)),
            transforms.RandomHorizontalFlip(),
            RGB2Lab(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = datasets.ImageFolder(
        val_folder,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            RGB2Lab(),
            transforms.ToTensor(),
            normalize,
        ]))
    print('number of train: {}'.format(len(train_dataset)))
    print('number of val: {}'.format(len(val_dataset)))

    if args.distributed:
        # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        train_sampler = distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    return train_loader, val_loader, train_sampler
def run(gpu, config):
    cudnn.benchmark = True
    if config['distribute']:
        rank = config['rank'] * config['last_node_gpus'] + gpu
        print("world_size: {}, rank: {}".format(config['world_size'], rank))
        dist.init_process_group(backend=config['backend'], init_method=config['ip'],
                                world_size=config['world_size'], rank=rank)

    assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
    torch.cuda.set_device(gpu)

    # create model
    model = AlexNet(10)
    # define loss function
    criterion = nn.CrossEntropyLoss()
    # define optimizer strategy
    optimizer = torch.optim.SGD(model.parameters(), config['lr'],
                                momentum=config['momentum'],
                                weight_decay=config['weight_decay'])

    # convert the pytorch model to an apex model
    apexparallel = ApexDistributeModel(model, criterion, optimizer, config, gpu)
    apexparallel.convert()
    apexparallel.lars()

    # load data
    data_path = '~/datasets/cifar10/train'
    train_set = LoadClassifyDataSets(data_path, 227)
    train_sampler = None
    if config['distribute']:
        train_sampler = distributed.DistributedSampler(train_set)
    train_loader = DataLoader(train_set, config['batch_size'],
                              shuffle=(train_sampler is None),
                              num_workers=config['num_workers'],
                              pin_memory=True, sampler=train_sampler,
                              collate_fn=collate_fn)

    for epo in range(config['epoch']):
        if config['distribute']:
            train_sampler.set_epoch(epo)
        # train for one epoch
        apexparallel.train(epo, train_loader)
def run(gpu, config):
    cudnn.benchmark = True
    if config['distribute']:
        rank = config['rank'] * config['last_node_gpus'] + gpu
        print("world_size: {}, rank: {}".format(config['world_size'], rank))
        dist.init_process_group(backend=config['backend'], init_method=config['ip'],
                                world_size=config['world_size'], rank=rank)

    assert cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    model = AlexNet(10)
    if config['sync_bn']:
        # synchronize batch norm statistics across processes
        model = apex.parallel.convert_syncbn_model(model)

    torch.cuda.set_device(gpu)
    model = model.cuda(gpu)

    # define loss function
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    # define optimizer strategy
    optimizer = torch.optim.SGD(model.parameters(), config['lr'],
                                momentum=config['momentum'],
                                weight_decay=config['weight_decay'])

    # initialize apex amp
    model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O0')
    if config['distribute']:
        # model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
        model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)

    # load data
    data_path = '~/datasets/cifar10/train'
    train_set = LoadClassifyDataSets(data_path, 227)
    train_sampler = None
    if config['distribute']:
        train_sampler = distributed.DistributedSampler(train_set)
    train_loader = DataLoader(train_set, config['batch_size'],
                              shuffle=(train_sampler is None),
                              num_workers=config['num_workers'],
                              pin_memory=True, sampler=train_sampler,
                              collate_fn=collate_fn)

    for epo in range(config['epoch']):
        if config['distribute']:
            train_sampler.set_epoch(epo)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epo, gpu)
def _get_train_data_loader(batch_size, training_dir, is_distributed, **kwargs):
    dataset = datasets.MNIST(training_dir,
                             train=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307, ), (0.3081, ))
                             ]),
                             download=True)
    train_sampler = dist.DistributedSampler(dataset) if is_distributed else None
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=train_sampler is None,
                      sampler=train_sampler,
                      **kwargs)
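# Not part of the example above: a sketch of how an entry point might decide is_distributed and
# initialize the process group from the standard launcher environment variables before calling
# _get_train_data_loader. The backend choice and training_dir are assumptions; MASTER_ADDR and
# MASTER_PORT are expected to be set by the launcher (e.g. torchrun).
import os
import torch.distributed


def _maybe_init_distributed(backend="gloo"):
    """Initialize the default process group from WORLD_SIZE/RANK env vars when present."""
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size <= 1:
        return False
    torch.distributed.init_process_group(backend=backend,
                                         rank=int(os.environ["RANK"]),
                                         world_size=world_size)
    return True


# hypothetical wiring to the helper defined above
is_distributed = _maybe_init_distributed()
train_loader = _get_train_data_loader(batch_size=64,
                                      training_dir="./data",
                                      is_distributed=is_distributed,
                                      num_workers=2)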
def __init__(self, data_path, data_specf, batch_size, shuffle=True, in_same_size=True,
             fetch_fn=None, prep_fn=None, collate_fn=None):
    name = os.path.basename(data_path).split('.')[0]
    rank = int(os.environ.get('RANK', -1))
    nworld = int(os.environ.get('WORLD_SIZE', 1))

    class TData(Dataset):
        def __init__(self, verbose=True):
            with h5py.File(data_path, 'r') as f:
                # make sure that the first dimension is batch
                self.length = len(f[list(data_specf.keys())[0]])
            if verbose:
                print('+' + (49 * '-') + '+')
                print('| \033[1;35m%-20s\033[0m fuel tank has been mounted |' % name)
                print('+' + (49 * '-') + '+')

        def _openH5(self):
            self.hdf5 = h5py.File(data_path, 'r')

        def __getitem__(self, idx: int):
            if not hasattr(self, 'hdf5'):
                self._openH5()
            item = fetch_fn(self.hdf5, idx)
            if prep_fn is not None:
                item = prep_fn(item)
            return item

        def __len__(self):
            return self.length

    self.name = name
    self.rank = rank
    self.tdata = TData(rank <= 0)
    self.counter = 0
    self.batch_size = batch_size
    if in_same_size:
        self.MPE = len(self.tdata) // (batch_size * nworld)
    else:
        self.MPE = ceil(len(self.tdata) / (batch_size * nworld))
    ncpus = cpu_count()
    if rank >= 0:
        from torch.utils.data import distributed as dist
        self.sampler = dist.DistributedSampler(self.tdata)
        self.tloader = DataLoader(self.tdata, batch_size, sampler=self.sampler,
                                  collate_fn=collate_fn, drop_last=in_same_size,
                                  num_workers=ncpus)
    else:
        self.tloader = DataLoader(self.tdata, batch_size, shuffle,
                                  collate_fn=collate_fn, drop_last=in_same_size,
                                  num_workers=ncpus)
def main():
    # Enable OBS access.
    mox.file.shift('os', 'mox')

    dist.init_process_group(backend='nccl', init_method=args.init_method,
                            rank=args.rank, world_size=args.world_size)

    dataset = datasets.MNIST(args.data_url,
                             train=True,
                             download=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307, ), (0.3081, ))
                             ]))
    train_sampler = distributed.DistributedSampler(dataset,
                                                   num_replicas=args.world_size,
                                                   rank=args.rank)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              sampler=train_sampler)

    model = Net()
    if torch.cuda.is_available():
        model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    for epoch in range(10):
        epoch_loss = 0.0
        train_sampler.set_epoch(epoch)
        for data, target in data_loader:
            optimizer.zero_grad()
            if torch.cuda.is_available():
                data, target = (data.cuda(), target.cuda())
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.data
            loss.backward()
            optimizer.step()
        print('epoch ', epoch, ' : ', epoch_loss / len(data_loader))

    if args.train_url and dist.get_rank() == 0:
        torch.save(model.state_dict(), args.train_url + 'model.pt')
def build_dataloader(dataset, cfg, args, phase='train'):
    batch_size = cfg.batch_size if 'train' == phase else cfg.test_batch_size
    collect_fn = COLLECT_FN[cfg.train_collect_fn] if 'train' == phase else COLLECT_FN[cfg.test_collect_fn]
    worker_init_function = worker_init_fn if 'train' == phase else worker_init_test_fn
    shuffle = 'train' == phase
    num_workers = args.num_workers
    if args.local_rank != -1:
        sampler = data_dist.DistributedSampler(dataset)
        num_workers = args.num_workers // dist.get_world_size()
        return get_dpp_loader(dataset, sampler, batch_size, num_workers,
                              collect_fn, worker_init_function)
    else:
        return get_loader(dataset,
                          batch_size=batch_size,
                          num_workers=num_workers,
                          collect_fn=collect_fn,
                          shuffle=shuffle,
                          worker_init_function=worker_init_function)
def get_dataloader(args, seed):
    if args.rank == 0:
        print("===> Get Dataloader...")

    setattr(args, 'mode', 'train')
    train_dataset, batch_size = get_data(args=args)
    train_sampler = distributed.DistributedSampler(train_dataset,
                                                   shuffle=True,
                                                   seed=seed,
                                                   num_replicas=args.world_size,
                                                   rank=args.rank)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=0,
                              pin_memory=True,
                              sampler=train_sampler,
                              collate_fn=train_collate_fn)
    setattr(args, 'train_data', train_loader)

    dist.barrier(device_ids=[args.rank])
def get_train_loader_from_txt(conf, data_mode, sample_identity=False):
    if data_mode == 'emore':
        txt_path = conf.emore_folder/'imgs'/'train_list.txt'
    elif data_mode == 'glint':
        txt_path = conf.glint_folder/'imgs'/'train_list.txt'
    else:
        logger.fatal('invalid data_mode {}'.format(data_mode))
        exit(1)
    train_transform = trans.Compose([
        trans.RandomHorizontalFlip(),
        trans.ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0),
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    dataset = ImageLandmarkDataset(txt_path, train_transform)
    if sample_identity:
        train_sampler = DistRandomIdentitySampler(dataset.dataset, conf.batch_size, conf.num_instances)
    else:
        train_sampler = distributed.DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=conf.batch_size, shuffle=False,
                        pin_memory=conf.pin_memory, num_workers=conf.num_workers,
                        sampler=train_sampler)
    return loader, dataset.class_num
if opt.beam_width > 1 and opt.phase == "infer":
    if not multi_gpu or hvd.rank() == 0:
        logger.info(f"Beam Width {opt.beam_width}")
    seq2seq.decoder = TopKDecoder(seq2seq.decoder, opt.beam_width)

if opt.phase == "train":
    # Prepare Train Data
    trans_data = TranslateData(pad_id)
    train_set = DialogDataset(opt.train_path,
                              trans_data.translate_data,
                              src_vocab,
                              tgt_vocab,
                              max_src_length=opt.max_src_length,
                              max_tgt_length=opt.max_tgt_length)
    train_sampler = dist.DistributedSampler(train_set, num_replicas=hvd.size(), rank=hvd.rank()) \
        if multi_gpu else None
    train = DataLoader(train_set,
                       batch_size=opt.batch_size,
                       shuffle=False if multi_gpu else True,
                       sampler=train_sampler,
                       drop_last=True,
                       collate_fn=trans_data.collate_fn)

    dev_set = DialogDataset(opt.dev_path,
                            trans_data.translate_data,
                            src_vocab,
                            tgt_vocab,
                            max_src_length=opt.max_src_length,
                            max_tgt_length=opt.max_tgt_length)
    dev_sampler = dist.DistributedSampler(dev_set, num_replicas=hvd.size(), rank=hvd.rank()) \
def get_distributed_data_loader(dataset, num_replicas, rank, train_batch_size=64, test_batch_size=64, seed=42, root_dir='data/'): if dataset == 'cifar10': transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) augment_transform = [ torchvision.transforms.RandomCrop(32, padding=4), torchvision.transforms.RandomHorizontalFlip() ] train_transform = torchvision.transforms.Compose( [*augment_transform, transform]) train_data = torchvision.datasets.CIFAR10(root_dir, train=True, download=True, transform=train_transform) train_sampler = distributed.DistributedSampler(train_data, seed=seed, shuffle=True, drop_last=True) train_loader = DataLoader(train_data, batch_size=train_batch_size, pin_memory=True, shuffle=False, num_workers=0, sampler=train_sampler, drop_last=True) test_data = torchvision.datasets.CIFAR10(root_dir, train=False, download=True, transform=transform) test_sampler = distributed.DistributedSampler(test_data, shuffle=False, drop_last=False) test_loader = DataLoader(test_data, batch_size=test_batch_size, pin_memory=True, shuffle=False, num_workers=0, sampler=test_sampler, drop_last=False) return train_loader, test_loader, train_sampler, test_sampler if dataset == 'cifar100': transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) augment_transform = [ torchvision.transforms.RandomCrop(32, padding=4), torchvision.transforms.RandomHorizontalFlip() ] train_transform = torchvision.transforms.Compose( [*augment_transform, transform]) train_data = torchvision.datasets.CIFAR100(root_dir, train=True, download=True, transform=train_transform) train_sampler = distributed.DistributedSampler(train_data, seed=seed, shuffle=True, drop_last=True) train_loader = DataLoader(train_data, batch_size=train_batch_size, pin_memory=True, shuffle=False, num_workers=0, sampler=train_sampler, drop_last=True) test_data = torchvision.datasets.CIFAR100(root_dir, train=False, download=True, transform=transform) test_sampler = distributed.DistributedSampler(test_data, shuffle=False, drop_last=False) test_loader = DataLoader(test_data, batch_size=test_batch_size, pin_memory=True, shuffle=False, num_workers=0, sampler=test_sampler) return train_loader, test_loader, train_sampler, test_sampler if dataset == 'vgg_cifar10': transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) augment_transform = [ torchvision.transforms.RandomCrop(32, padding=4), torchvision.transforms.RandomHorizontalFlip() ] train_transform = torchvision.transforms.Compose( [*augment_transform, transform]) train_data = torchvision.datasets.CIFAR10(root_dir, train=True, download=True, transform=train_transform) train_sampler = distributed.DistributedSampler(train_data, seed=seed, shuffle=True, drop_last=True) train_loader = DataLoader(train_data, batch_size=train_batch_size, pin_memory=True, shuffle=False, num_workers=0, sampler=train_sampler, drop_last=True) test_data = torchvision.datasets.CIFAR10(root_dir, train=False, download=True, transform=transform) test_sampler = distributed.DistributedSampler(test_data, shuffle=False, drop_last=False) test_loader = DataLoader(test_data, batch_size=test_batch_size, pin_memory=True, shuffle=False, num_workers=0, sampler=test_sampler) return train_loader, test_loader, train_sampler, 
test_sampler if dataset == 'vgg_cifar100': transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) augment_transform = [ torchvision.transforms.RandomCrop(32, padding=4), torchvision.transforms.RandomHorizontalFlip() ] train_transform = torchvision.transforms.Compose( [*augment_transform, transform]) train_data = torchvision.datasets.CIFAR100(root_dir, train=True, download=True, transform=train_transform) train_sampler = distributed.DistributedSampler(train_data, seed=seed, shuffle=True, drop_last=True) train_loader = DataLoader(train_data, batch_size=train_batch_size, pin_memory=True, shuffle=False, num_workers=0, sampler=train_sampler, drop_last=True) test_data = torchvision.datasets.CIFAR100(root_dir, train=False, download=True, transform=transform) test_sampler = distributed.DistributedSampler(test_data, shuffle=False, drop_last=False) test_loader = DataLoader(test_data, batch_size=test_batch_size, pin_memory=True, shuffle=False, num_workers=0, sampler=test_sampler) return train_loader, test_loader, train_sampler, test_sampler
def upsnet_train(): if is_master: logger.info('training config:{}\n'.format(pprint.pformat(config))) gpus = [torch.device('cuda', int(_)) for _ in config.gpus.split(',')] num_replica = hvd.size() if config.train.use_horovod else len(gpus) num_gpus = 1 if config.train.use_horovod else len(gpus) # create models train_model = eval(config.symbol)().cuda() # create optimizer params_lr = train_model.get_params_lr() # we use custom optimizer and pass lr=1 to support different lr for different weights optimizer = SGD(params_lr, lr=1, momentum=config.train.momentum, weight_decay=config.train.wd) if config.train.use_horovod: optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=train_model.named_parameters()) optimizer.zero_grad() # create data loader train_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.image_set.split('+'), flip=config.train.flip, result_path=final_output_path) val_dataset = eval(config.dataset.dataset)(image_sets=config.dataset.test_image_set.split('+'), flip=False, result_path=final_output_path, phase='val') if config.train.use_horovod: train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank()) val_sampler = distributed.DistributedSampler(val_dataset, num_replicas=hvd.size(), rank=hvd.rank()) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, sampler=train_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, sampler=val_sampler, num_workers=num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate) else: train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=config.train.shuffle, num_workers=num_gpus * 4 if not config.debug_mode else num_gpus * 4, drop_last=False, collate_fn=train_dataset.collate) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.train.batch_size, shuffle=False, num_workers=num_gpus * 4 if not config.debug_mode else num_gpus * 4, drop_last=False, collate_fn=val_dataset.collate) # preparing curr_iter = config.train.begin_iteration batch_end_callback = [Speedometer(num_replica * config.train.batch_size, config.train.display_iter)] metrics = [] metrics_name = [] if config.network.has_rpn: metrics.extend([AvgMetric(name='rpn_cls_loss'), AvgMetric(name='rpn_bbox_loss'),]) metrics_name.extend(['rpn_cls_loss', 'rpn_bbox_loss']) if config.network.has_rcnn: metrics.extend([AvgMetric(name='rcnn_accuracy'), AvgMetric(name='cls_loss'), AvgMetric(name='bbox_loss'),]) metrics_name.extend(['rcnn_accuracy', 'cls_loss', 'bbox_loss']) if config.network.has_mask_head: metrics.extend([AvgMetric(name='mask_loss'), ]) metrics_name.extend(['mask_loss']) if config.network.has_fcn_head: metrics.extend([AvgMetric(name='fcn_loss'), ]) metrics_name.extend(['fcn_loss']) if config.train.fcn_with_roi_loss: metrics.extend([AvgMetric(name='fcn_roi_loss'), ]) metrics_name.extend(['fcn_roi_loss']) if config.network.has_panoptic_head: metrics.extend([AvgMetric(name='panoptic_accuracy'), AvgMetric(name='panoptic_loss'), ]) metrics_name.extend(['panoptic_accuracy', 'panoptic_loss']) if config.train.resume: train_model.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')), resume=True) optimizer.load_state_dict(torch.load(os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))) if config.train.use_horovod: 
hvd.broadcast_parameters(train_model.state_dict(), root_rank=0) else: if is_master: train_model.load_state_dict(torch.load(config.network.pretrained)) if config.train.use_horovod: hvd.broadcast_parameters(train_model.state_dict(), root_rank=0) if not config.train.use_horovod: train_model = DataParallel(train_model, device_ids=[int(_) for _ in config.gpus.split(',')]).to(gpus[0]) if is_master: batch_end_callback[0](0, 0) train_model.eval() # start training while curr_iter < config.train.max_iteration: if config.train.use_horovod: train_sampler.set_epoch(curr_iter) if config.network.use_syncbn: train_model.train() if config.network.backbone_freeze_at > 0: train_model.freeze_backbone(config.network.backbone_freeze_at) if config.network.backbone_fix_bn: train_model.resnet_backbone.eval() for inner_iter, batch in enumerate(train_loader): data, label, _ = batch for k, v in data.items(): data[k] = v if not torch.is_tensor(v) else v.cuda() for k, v in label.items(): label[k] = v if not torch.is_tensor(v) else v.cuda() lr = adjust_learning_rate(optimizer, curr_iter, config) optimizer.zero_grad() output = train_model(data, label) loss = 0 if config.network.has_rpn: loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean() if config.network.has_rcnn: loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() * config.train.bbox_loss_weight if config.network.has_mask_head: loss = loss + output['mask_loss'].mean() if config.network.has_fcn_head: loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight if config.train.fcn_with_roi_loss: loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2 if config.network.has_panoptic_head: loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight loss.backward() optimizer.step(lr) losses = [] losses.append(allreduce_async(loss, name='train_total_loss')) for l in metrics_name: losses.append(allreduce_async(output[l].mean(), name=l)) loss = hvd.synchronize(losses[0]).item() if is_master: writer.add_scalar('train_total_loss', loss, curr_iter) for i, (metric, l) in enumerate(zip(metrics, metrics_name)): loss = hvd.synchronize(losses[i + 1]).item() if is_master: writer.add_scalar('train_' + l, loss, curr_iter) metric.update(_, _, loss) curr_iter += 1 if curr_iter in config.train.decay_iteration: if is_master: logger.info('decay momentum buffer') for k in optimizer.state_dict()['state'].keys(): if 'momentum_buffer' in optimizer.state_dict()['state'][k]: optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10) if is_master: if curr_iter % config.train.display_iter == 0: for callback in batch_end_callback: callback(curr_iter, metrics) if curr_iter % config.train.snapshot_step == 0: logger.info('taking snapshot ...') torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth')) torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth')) else: inner_iter = 0 train_iterator = train_loader.__iter__() while inner_iter + num_gpus <= len(train_loader): batch = [] for gpu_id in gpus: data, label, _ = train_iterator.next() for k, v in data.items(): data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True) for k, v in label.items(): label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True) batch.append((data, label)) inner_iter += 1 lr = adjust_learning_rate(optimizer, curr_iter, config) optimizer.zero_grad() if 
config.train.use_horovod: output = train_model(data, label) else: output = train_model(*batch) loss = 0 if config.network.has_rpn: loss = loss + output['rpn_cls_loss'].mean() + output['rpn_bbox_loss'].mean() if config.network.has_rcnn: loss = loss + output['cls_loss'].mean() + output['bbox_loss'].mean() if config.network.has_mask_head: loss = loss + output['mask_loss'].mean() if config.network.has_fcn_head: loss = loss + output['fcn_loss'].mean() * config.train.fcn_loss_weight if config.train.fcn_with_roi_loss: loss = loss + output['fcn_roi_loss'].mean() * config.train.fcn_loss_weight * 0.2 if config.network.has_panoptic_head: loss = loss + output['panoptic_loss'].mean() * config.train.panoptic_loss_weight loss.backward() optimizer.step(lr) losses = [] losses.append(loss.item()) for l in metrics_name: losses.append(output[l].mean().item()) loss = losses[0] if is_master: writer.add_scalar('train_total_loss', loss, curr_iter) for i, (metric, l) in enumerate(zip(metrics, metrics_name)): loss = losses[i + 1] if is_master: writer.add_scalar('train_' + l, loss, curr_iter) metric.update(_, _, loss) curr_iter += 1 if curr_iter in config.train.decay_iteration: if is_master: logger.info('decay momentum buffer') for k in optimizer.state_dict()['state'].keys(): optimizer.state_dict()['state'][k]['momentum_buffer'].div_(10) if is_master: if curr_iter % config.train.display_iter == 0: for callback in batch_end_callback: callback(curr_iter, metrics) if curr_iter % config.train.snapshot_step == 0: logger.info('taking snapshot ...') torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.pth')) torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix+str(curr_iter)+'.state.pth')) while True: try: train_iterator.next() except: break for metric in metrics: metric.reset() if config.train.eval_data: train_model.eval() if config.train.use_horovod: for inner_iter, batch in enumerate(val_loader): data, label, _ = batch for k, v in data.items(): data[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True) for k, v in label.items(): label[k] = v if not torch.is_tensor(v) else v.cuda(non_blocking=True) with torch.no_grad(): output = train_model(data, label) for metric, l in zip(metrics, metrics_name): loss = hvd.allreduce(output[l].mean()).item() if is_master: metric.update(_, _, loss) else: inner_iter = 0 val_iterator = val_loader.__iter__() while inner_iter + len(gpus) <= len(val_loader): batch = [] for gpu_id in gpus: data, label, _ = val_iterator.next() for k, v in data.items(): data[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True) for k, v in label.items(): label[k] = v if not torch.is_tensor(v) else v.pin_memory().to(gpu_id, non_blocking=True) batch.append((data, label)) inner_iter += 1 with torch.no_grad(): if config.train.use_horovod: output = train_model(data, label) else: output = train_model(*batch) losses = [] for l in metrics_name: losses.append(allreduce_async(output[l].mean(), name=l) if config.train.use_horovod else output[l].mean().item()) for metric, loss in zip(metrics, losses): loss = hvd.synchronize(loss).item() if config.train.use_horovod else loss if is_master: metric.update(_, _, loss) while True: try: val_iterator.next() except Exception: break s = 'Batch [%d]\t Epoch[%d]\t' % (curr_iter, curr_iter // len(train_loader)) for metric in metrics: m, v = metric.get() s += 'Val-%s=%f,\t' % (m, v) if is_master: writer.add_scalar('val_' + m, v, curr_iter) metric.reset() if 
is_master: logger.info(s) if is_master and config.train.use_horovod: logger.info('taking snapshot ...') torch.save(train_model.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')) torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth')) elif not config.train.use_horovod: logger.info('taking snapshot ...') torch.save(train_model.module.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.pth')) torch.save(optimizer.state_dict(), os.path.join(final_output_path, config.model_prefix + str(curr_iter) + '.state.pth'))
def run(gpu_id, config):
    cudnn.benchmark = True
    if config['distribute']:
        rank = config['rank'] * config['last_node_gpus'] + gpu_id
        gpu_id = gpu_id + config['start_node_gpus']
        print("world_size: {}, rank: {}, gpu: {}".format(config['world_size'], rank, gpu_id))
        dist.init_process_group(backend=config['backend'], init_method=config['ip'],
                                world_size=config['world_size'], rank=rank)

    assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
    torch.manual_seed(42)

    # create model
    model = DarknetClassify(darknet53())
    # define loss function
    criterion = nn.CrossEntropyLoss()
    # define optimizer strategy
    optimizer = torch.optim.SGD(model.parameters(), config['lr'],
                                momentum=config['momentum'],
                                weight_decay=config['weight_decay'])
    # define lr strategy
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=range(0, config['epochs'], config['epochs'] // 4),
                                                     gamma=0.5)

    if config.get('resume_path'):
        loc = 'cuda:{}'.format(gpu_id)
        checkpoint = torch.load(config.get('resume_path'), map_location=loc)
        model.load_state_dict(checkpoint['state_dict'])

    # convert the pytorch model to an apex model
    parallel = ConvertModel(model, criterion, optimizer, config, gpu_id)
    parallel.convert()

    # load data
    train_sets, val_sets = load_imagenet_data(config['data_path'])
    train_sampler = None
    if config['distribute']:
        train_sampler = distributed.DistributedSampler(train_sets)
    train_loader = DataLoader(train_sets, config['batch_size'],
                              shuffle=(train_sampler is None),
                              num_workers=config['num_workers'],
                              pin_memory=True, sampler=train_sampler)
    val_loader = DataLoader(val_sets, config['batch_size'], shuffle=False,
                            num_workers=config['num_workers'], pin_memory=True)
    dist.barrier()

    best_acc1 = 0
    # start training
    for epoch in range(config['epochs']):
        if config['distribute']:
            train_sampler.set_epoch(epoch)
        # train for one epoch
        loss = parallel.train(epoch, train_loader)
        dist.barrier()
        lr = parallel.get_lr()
        if rank == 0:
            print('Epoch: [{}/{}], Lr: {:.8f}'.format(epoch, config['epochs'], lr))
            # evaluate on validation set
            acc1, acc5 = validate(val_loader, model, criterion, gpu_id)
            if config['record']:
                with open('record.log', 'a') as f:
                    f.write('Epoch {}, lr {:.8f}, loss: {:.8f}, Acc@1 {:.8f}, Acc5@ {:.8f} \n'.format(
                        epoch, lr, loss, acc1, acc5))
            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            if is_best:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, config['model_path'].format(epoch, best_acc1, loss))
        scheduler.step(loss)
        dist.barrier()

    dist.destroy_process_group()
def get_train_val_loaders(root_path: str,
                          train_transforms: Callable,
                          val_transforms: Callable,
                          batch_size: int = 16,
                          num_workers: int = 8,
                          val_batch_size: Optional[int] = None,
                          pin_memory: bool = True,
                          random_seed: Optional[int] = None,
                          train_sampler: Optional[Union[Sampler, str]] = None,
                          val_sampler: Optional[Union[Sampler, str]] = None,
                          with_sbd: Optional[str] = None,
                          limit_train_num_samples: Optional[int] = None,
                          limit_val_num_samples: Optional[int] = None) -> Tuple[DataLoader, DataLoader, DataLoader]:
    train_ds = get_train_dataset(root_path)
    val_ds = get_val_dataset(root_path)

    if with_sbd is not None:
        sbd_train_ds = get_train_noval_sbdataset(with_sbd)
        train_ds = ConcatDataset([train_ds, sbd_train_ds])

    if random_seed is not None:
        np.random.seed(random_seed)

    if limit_train_num_samples is not None:
        train_indices = np.random.permutation(len(train_ds))[:limit_train_num_samples]
        train_ds = Subset(train_ds, train_indices)

    if limit_val_num_samples is not None:
        val_indices = np.random.permutation(len(val_ds))[:limit_val_num_samples]
        val_ds = Subset(val_ds, val_indices)

    # random samples for evaluation on training dataset
    if len(val_ds) < len(train_ds):
        train_eval_indices = np.random.permutation(len(train_ds))[:len(val_ds)]
        train_eval_ds = Subset(train_ds, train_eval_indices)
    else:
        train_eval_ds = train_ds

    train_ds = TransformedDataset(train_ds, transform_fn=train_transforms)
    val_ds = TransformedDataset(val_ds, transform_fn=val_transforms)
    train_eval_ds = TransformedDataset(train_eval_ds, transform_fn=val_transforms)

    if isinstance(train_sampler, str):
        assert train_sampler == 'distributed'
        train_sampler = data_dist.DistributedSampler(train_ds)

    if isinstance(val_sampler, str):
        assert val_sampler == 'distributed'
        val_sampler = data_dist.DistributedSampler(val_ds, shuffle=False)

    train_loader = DataLoader(train_ds, shuffle=train_sampler is None,
                              batch_size=batch_size, num_workers=num_workers,
                              sampler=train_sampler,
                              pin_memory=pin_memory, drop_last=True)

    val_batch_size = batch_size * 4 if val_batch_size is None else val_batch_size
    val_loader = DataLoader(val_ds, shuffle=False, sampler=val_sampler,
                            batch_size=val_batch_size, num_workers=num_workers,
                            pin_memory=pin_memory, drop_last=False)

    train_eval_loader = DataLoader(train_eval_ds, shuffle=False, sampler=val_sampler,
                                   batch_size=val_batch_size, num_workers=num_workers,
                                   pin_memory=pin_memory, drop_last=False)

    return train_loader, val_loader, train_eval_loader
def main(args): utils.init_distributed_mode(args) device = torch.device(args.gpus) in_chns = 3 if args.vision_type == 'monochromat': in_chns = 1 elif 'dichromat' in args.vision_type: in_chns = 2 data_reading_kwargs = { 'target_size': args.target_size, 'colour_vision': args.vision_type, 'colour_space': args.colour_space } dataset, num_classes = utils.get_dataset(args.dataset, args.data_dir, 'train', **data_reading_kwargs) json_file_name = os.path.join(args.out_dir, 'args.json') with open(json_file_name, 'w') as fp: json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4) dataset_test, _ = utils.get_dataset(args.dataset, args.data_dir, 'val', **data_reading_kwargs) if args.distributed: train_sampler = torch_dist.DistributedSampler(dataset) test_sampler = torch_dist.DistributedSampler(dataset_test) else: train_sampler = torch_data.RandomSampler(dataset) test_sampler = torch_data.SequentialSampler(dataset_test) data_loader = torch_data.DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, collate_fn=utils.collate_fn, drop_last=True) data_loader_test = torch_data.DataLoader(dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn) if args.network_name == 'unet': model = segmentation_models.unet.model.Unet( encoder_weights=args.backbone, classes=num_classes) if args.pretrained: print('Loading %s' % args.pretrained) checkpoint = torch.load(args.pretrained, map_location='cpu') remove_keys = [] for key_ind, key in enumerate(checkpoint['state_dict'].keys()): if 'segmentation_head' in key: remove_keys.append(key) for key in remove_keys: del checkpoint['state_dict'][key] model.load_state_dict(checkpoint['state_dict'], strict=False) elif args.custom_arch: print('Custom model!') backbone_name, customs = model_utils.create_custom_resnet( args.backbone, None) if customs is not None: args.backbone = {'arch': backbone_name, 'customs': customs} model = custom_models.__dict__[args.network_name]( args.backbone, num_classes=num_classes, aux_loss=args.aux_loss) if args.pretrained: print('Loading %s' % args.pretrained) checkpoint = torch.load(args.pretrained, map_location='cpu') num_all_keys = len(checkpoint['state_dict'].keys()) remove_keys = [] for key_ind, key in enumerate(checkpoint['state_dict'].keys()): if key_ind > (num_all_keys - 3): remove_keys.append(key) for key in remove_keys: del checkpoint['state_dict'][key] pretrained_weights = OrderedDict( (k.replace('segmentation_model.', ''), v) for k, v in checkpoint['state_dict'].items()) model.load_state_dict(pretrained_weights, strict=False) else: model = seg_models.__dict__[args.network_name]( num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) best_iou = 0 model_progress = [] model_progress_path = os.path.join(args.out_dir, 'model_progress.csv') # loading the model if to eb resumed if args.resume is not None: checkpoint = torch.load(args.resume, map_location='cpu') model.load_state_dict(checkpoint['model']) best_iou = checkpoint['best_iou'] # if model progress exists, load it if os.path.exists(model_progress_path): model_progress = np.loadtxt(model_progress_path, delimiter=',') model_progress = model_progress.tolist() master_model = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gpus]) master_model = model.module if args.network_name == 'unet': params_to_optimize = model.parameters() 
else: params_to_optimize = [ { 'params': [ p for p in master_model.backbone.parameters() if p.requires_grad ] }, { 'params': [ p for p in master_model.classifier.parameters() if p.requires_grad ] }, ] if args.aux_loss: params = [ p for p in master_model.aux_classifier.parameters() if p.requires_grad ] params_to_optimize.append({'params': params, 'lr': args.lr * 10}) optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) lr_lambda = lambda x: (1 - x / (len(data_loader) * args.epochs))**0.9 lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) criterion = select_criterion(args.dataset) start_time = time.time() for epoch in range(args.initial_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) train_log = train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq) val_confmat = utils.evaluate(model, data_loader_test, device=device, num_classes=num_classes) val_log = val_confmat.get_log_dict() is_best = val_log['iou'] > best_iou best_iou = max(best_iou, val_log['iou']) model_data = { 'epoch': epoch + 1, 'arch': args.network_name, 'customs': { 'aux_loss': args.aux_loss, 'pooling_type': args.pooling_type, 'in_chns': in_chns, 'num_classes': num_classes, 'backbone': args.backbone }, 'state_dict': master_model.state_dict(), 'optimizer': optimizer.state_dict(), 'target_size': args.target_size, 'args': args, 'best_iou': best_iou, } utils.save_on_master(model_data, os.path.join(args.out_dir, 'checkpoint.pth')) if is_best: utils.save_on_master(model_data, os.path.join(args.out_dir, 'model_best.pth')) epoch_prog, header = add_to_progress(train_log, [], '') epoch_prog, header = add_to_progress(val_log, epoch_prog, header, prefix='v_') model_progress.append(epoch_prog) np.savetxt(model_progress_path, np.array(model_progress), delimiter=';', header=header, fmt='%s') total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str))
def worker(pid, ngpus_per_node, args): """ Note: Until platform setting, everything runs on CPU side. """ # -------------------------------------------------------- # # Initialization s# # -------------------------------------------------------- # configs = {} with open(args.config, "r") as json_config: configs = json.load(json_config) # NOTE: # -- For distributed data parallel, we use 1 process for 1 GPU and # set GPU ID = Process ID # -- For data parallel, we only has 1 process and the master GPU is # GPU 0 args.gid = pid if pid is not None: torch.cuda.set_device(args.gid) if args.distributed: args.rank = args.rank * ngpus_per_node + args.gid dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) if args.distributed: print("Proc [{:2d}], Rank [{:2d}], Uses GPU [{:2d}]".format( pid, args.rank, args.gid ) ) global best_prec1 start_epoch = 0 checkpoint = None # -------------------------------------------------------- # # Construct Datasets & Dataloaders # # -------------------------------------------------------- # # construct training set print("Proc [{:2d}] constructing training set...".format(pid)) train_transforms = [] for _t in configs["train_transforms"]: train_transforms.append(cfgs.config2transform(_t)) train_transform = torchstream.transforms.Compose( transforms=train_transforms ) configs["train_dataset"]["argv"]["transform"] = train_transform train_dataset = cfgs.config2dataset(configs["train_dataset"]) # TODO: integrate into configuration? if args.distributed: train_sampler = datadist.DistributedSampler(train_dataset, shuffle=True) else: train_sampler = None if args.distributed: configs["train_loader"]["batch_size"] = \ int(configs["train_loader"]["batch_size"] / args.world_size) configs["train_loader"]["num_workers"] = \ int(configs["train_loader"]["num_workers"] / args.world_size) # turn off the shuffle option outside, set shuffule in sampler configs["train_loader"]["shuffle"] = False configs["train_loader"]["dataset"] = train_dataset configs["train_loader"]["sampler"] = train_sampler train_loader = cfgs.config2dataloader(configs["train_loader"]) # construct validation set print("Proc [{:2d}] constructing validation set...".format(pid)) val_transforms = [] for _t in configs["val_transforms"]: val_transforms.append(cfgs.config2transform(_t)) val_transform = torchstream.transforms.Compose( transforms=val_transforms ) configs["val_dataset"]["argv"]["transform"] = val_transform val_dataset = cfgs.config2dataset(configs["val_dataset"]) if args.distributed: val_sampler = datadist.DistributedSampler(val_dataset, shuffle=False) else: val_sampler = None if args.distributed: configs["val_loader"]["batch_size"] = \ int(configs["val_loader"]["batch_size"] / args.world_size) configs["val_loader"]["num_workers"] = \ int(configs["val_loader"]["num_workers"] / args.world_size) configs["val_loader"]["dataset"] = val_dataset configs["val_loader"]["sampler"] = val_sampler val_loader = cfgs.config2dataloader(configs["val_loader"]) # -------------------------------------------------------- # # Construct Neural Network # # -------------------------------------------------------- # model = cfgs.config2model(configs["model"]) # load checkpoint if "resume" in configs["train"]: # NOTE: the 1st place to load checkpoint resume_config = configs["train"]["resume"] checkpoint = utils.load_checkpoint(**resume_config) if checkpoint is None: print("Load Checkpoint Failed") if checkpoint is not None: # check checkpoint device mapping 
model_state_dict = checkpoint["model_state_dict"] print("Loading Checkpoint...") model.load_state_dict(model_state_dict) # ignore finetune if there is a checkpoint if (checkpoint is None) and ("finetune" in configs["train"]): finetune_config = configs["train"]["finetune"] checkpoint = utils.load_checkpoint(**finetune_config) if checkpoint is None: raise ValueError("Load Finetune Model Failed") # TODO: move load finetune model into model's method # not all models replace FCs only model_state_dict = checkpoint["model_state_dict"] for key in model_state_dict: if "fc" in key: # use FC from new network print("Replacing ", key) model_state_dict[key] = model.state_dict()[key] model.load_state_dict(model_state_dict) # set to None to prevent loading other states checkpoint = None # move to device model = model.cuda(args.gid) if args.distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.gid], find_unused_parameters=True ) else: model = torch.nn.DataParallel(model) # -------------------------------------------------------- # # Construct Optimizer, Scheduler etc # # -------------------------------------------------------- # print("Setting Optimizer & Lr Scheduler...") if configs["optimizer"]["argv"]["params"] == "model_specified": print("Use Model Specified Training Policies") configs["optimizer"]["argv"]["params"] = \ model.module.get_optim_policies() else: print("Train All Parameters") configs["optimizer"]["argv"]["params"] = model.parameters() optimizer = cfgs.config2optimizer(configs["optimizer"]) lr_scheduler = cfgs.config2lrscheduler(optimizer, configs["lr_scheduler"]) if "resume" in configs["train"]: if checkpoint is not None: best_prec1 = checkpoint["best_prec1"] start_epoch = checkpoint["epoch"] + 1 model_state_dict = checkpoint["model_state_dict"] optimizer_state_dict = checkpoint["optimizer_state_dict"] lr_scheduler_state_dict = checkpoint["lr_scheduler_state_dict"] optimizer.load_state_dict(optimizer_state_dict) lr_scheduler.load_state_dict(lr_scheduler_state_dict) print("Resume from epoch [{}], best prec1 [{}]". 
format(start_epoch - 1, best_prec1)) criterion = cfgs.config2criterion(configs["criterion"]) criterion = criterion.cuda(args.gid) # -------------------------------------------------------- # # Main Loop # # -------------------------------------------------------- # backup_config = None if "backup" in configs["train"]: backup_config = configs["train"]["backup"] epochs = configs["train"]["epochs"] print("Training Begins") for epoch in range(start_epoch, epochs): if args.distributed: train_sampler.set_epoch(epoch) val_sampler.set_epoch(epoch) # train for one epoch train(gid=args.gid, loader=train_loader, model=model, criterion=criterion, optimizer=optimizer, lr_scheduler=lr_scheduler, epoch=epoch) # evaluate on validation set prec1 = validate(gid=args.gid, loader=val_loader, model=model, criterion=criterion, epoch=epoch) # aproxiamation in distributed mode # currently, each process has the same number of samples (via padding) # so we directly apply `all_reduce` without weights if args.distributed: prec1_tensor = torch.Tensor([prec1]).cuda(args.gid) # print("[{}] Before reduce: {}".format(pid, prec1_tensor)) dist.all_reduce(prec1_tensor) # print("[{}] after reduce: {}".format(pid, prec1_tensor)) prec1 = prec1_tensor.item() / args.world_size # remember best prec@1 if (not args.distributed) or (args.rank == 0): print("*" * 80) print("Final Prec1: {:5.3f}".format(prec1)) is_best = prec1 > best_prec1 best_prec1 = max(prec1, best_prec1) print("Best Prec@1: %.3f" % (best_prec1)) print("*" * 80) # save checkpoint at rank 0 (not process 0, NFS!!!) if args.rank == 0: if backup_config is not None: dir_path = backup_config["dir_path"] pth_name = backup_config["pth_name"] model_state_dict = model.state_dict() optimizer_state_dict = optimizer.state_dict() lr_scheduler_state_dict = lr_scheduler.state_dict() # remove prefixes in (distributed) data parallel wrapper utils.checkpoint.remove_prefix_in_keys(model_state_dict) checkpoint = { "epoch": epoch, "model_state_dict": model_state_dict, "optimizer_state_dict": optimizer.state_dict(), "lr_scheduler_state_dict": lr_scheduler.state_dict(), "best_prec1": best_prec1 } utils.save_checkpoint(checkpoint=checkpoint, is_best=is_best, dir_path=dir_path, pth_name=pth_name)