def create_dataloader(config, data, mode):
    dataset = create_dataset(config, data, mode)
    if mode == 'train':
        # create Sampler
        if dist.is_available() and dist.is_initialized():
            train_RandomSampler = distributed.DistributedSampler(dataset)
        else:
            train_RandomSampler = sampler.RandomSampler(dataset, replacement=False)
        train_BatchSampler = sampler.BatchSampler(train_RandomSampler,
                                                  batch_size=config.train.batch_size,
                                                  drop_last=config.train.dataloader.drop_last)
        # Augment
        collator = get_collate_fn(config)
        # DataLoader
        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=train_BatchSampler,
                                 collate_fn=collator,
                                 pin_memory=config.train.dataloader.pin_memory,
                                 num_workers=config.train.dataloader.work_nums)
    elif mode == 'val':
        if dist.is_available() and dist.is_initialized():
            val_SequentialSampler = distributed.DistributedSampler(dataset)
        else:
            val_SequentialSampler = sampler.SequentialSampler(dataset)
        val_BatchSampler = sampler.BatchSampler(val_SequentialSampler,
                                                batch_size=config.val.batch_size,
                                                drop_last=config.val.dataloader.drop_last)
        data_loader = DataLoader(dataset,
                                 batch_sampler=val_BatchSampler,
                                 pin_memory=config.val.dataloader.pin_memory,
                                 num_workers=config.val.dataloader.work_nums)
    else:
        if dist.is_available() and dist.is_initialized():
            test_SequentialSampler = distributed.DistributedSampler(dataset)
        else:
            test_SequentialSampler = None
        data_loader = DataLoader(dataset,
                                 sampler=test_SequentialSampler,
                                 batch_size=config.test.batch_size,
                                 pin_memory=config.val.dataloader.pin_memory,
                                 num_workers=config.val.dataloader.work_nums)
    return data_loader
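# A minimal, self-contained sketch (not part of the function above) of the
# sampler -> BatchSampler -> DataLoader chain that create_dataloader builds.
# The toy TensorDataset and batch size are assumptions for illustration; note
# that DataLoader must not also receive batch_size/shuffle when batch_sampler
# is given.
import torch
from torch.utils.data import DataLoader, TensorDataset, sampler

toy = TensorDataset(torch.arange(10).float().unsqueeze(1))
rand_sampler = sampler.RandomSampler(toy, replacement=False)
batch_sampler = sampler.BatchSampler(rand_sampler, batch_size=4, drop_last=False)
loader = DataLoader(toy, batch_sampler=batch_sampler)
for (x,) in loader:
    print(x.squeeze(1))   # 4 + 4 + 2 samples in random order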
def setup_data(self):
    transform = unrel.TRANSFORM
    # Initialize trainset
    self.trainset = data.Dataset(split='train', pairs='annotated', transform=transform)
    if self.opts.train_size:
        print('Using subset of %d from train_set' % self.opts.train_size)
        batch_sampler = sampler.SequentialSampler(range(self.opts.train_size))
    else:
        batch_sampler = None
    self.trainloader = data.FauxDataLoader(self.trainset,
                                           sampler=batch_sampler,
                                           batch_size=self.opts.batch_size)
    # Initialize testset
    if self.opts.do_validation:
        self.testset = data.Dataset(split='test', pairs='annotated', transform=transform)
        # Load the test set without shuffling so that we can use Tyler's RecallEvaluator
        batch_sampler = sampler.BatchSampler(sampler.SequentialSampler(self.testset),
                                             self.opts.test_batch_size,
                                             False)
        self.testloaders = [data.FauxDataLoader(self.testset, sampler=batch_sampler)]
    else:
        print('No testset')
        self.testloaders = []
def __iter__(self):
    for bucket in self.bucket_sampler:
        sorted_sampler = SortedSampler(bucket, self.sort_key)
        for batch in samplers.SubsetRandomSampler(
                list(samplers.BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
            yield [bucket[i] for i in batch]
def __init__(self, dataset, batch_size, shuffle=True, drop_last=False):
    # Group the dataset into buckets by context length
    self.buckets = bucket(dataset)
    # Shuffle the list of buckets
    if shuffle:
        np.random.shuffle(self.buckets)
        random_samplers = [sampler.RandomSampler(bucket) for bucket in self.buckets]
    else:
        random_samplers = [sampler.SequentialSampler(bucket) for bucket in self.buckets]
    self.sampler = [sampler.BatchSampler(s, batch_size, drop_last) for s in random_samplers]
def _init_sampler(self, **kwargs):
    sampler = kwargs.get('sampler')
    batch_size = kwargs.get('batch_size', 1)
    drop_last = kwargs.get('drop_last', False)
    if isinstance(sampler, torchsampler.BatchSampler):
        return sampler
    if sampler is None:
        sampler = torchsampler.RandomSampler(self.dataset)
    elif not isinstance(sampler, torchsampler.Sampler):
        sampler = torchsampler.RandomSampler(sampler)
    return torchsampler.BatchSampler(sampler, batch_size, drop_last)
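# A small, hedged illustration (standalone; the dataset and names are
# assumptions) of the normalization _init_sampler performs above: an existing
# BatchSampler passes through unchanged, None turns into a RandomSampler over
# self.dataset, and any other index collection is wrapped in a RandomSampler
# before batching.
import torch.utils.data.sampler as torchsampler

dataset = list(range(20))                       # stand-in for self.dataset
default_batches = torchsampler.BatchSampler(
    torchsampler.RandomSampler(dataset), batch_size=4, drop_last=False)
subset_batches = torchsampler.BatchSampler(
    torchsampler.RandomSampler(range(10)), batch_size=4, drop_last=False)
print(list(default_batches))   # 5 batches of 4 shuffled indices over all 20 items
print(list(subset_batches))    # batches drawn from the first 10 positions only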
def __init__(
    self,
    sampler,
    batch_size,
    drop_last,
    sort_key,
    bucket_size_multiplier=100,
):
    super().__init__(sampler, batch_size, drop_last)
    self.sort_key = sort_key
    self.bucket_sampler = samplers.BatchSampler(
        sampler,
        min(batch_size * bucket_size_multiplier, len(sampler)),
        False)
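# A standalone, hedged sketch of the bucketing strategy the __init__ above sets
# up (together with the __iter__ shown earlier): pull a large bucket of
# indices, sort the bucket by a length key so each batch holds similarly sized
# examples, then visit the batches in random order. All names and data here are
# illustrative, not part of the original class.
from torch.utils.data import sampler as samplers

lengths = [7, 2, 9, 4, 11, 3, 8, 5, 6, 1]           # e.g. token counts per example
bucket = list(range(len(lengths)))                   # one bucket covering everything
bucket.sort(key=lambda i: lengths[i])                # stand-in for SortedSampler
batches = list(samplers.BatchSampler(bucket, batch_size=3, drop_last=False))
for j in samplers.SubsetRandomSampler(range(len(batches))):
    print([lengths[i] for i in batches[j]])          # similar lengths per batch, random batch order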
def create_training_batch(train_data, batch_size):
    """Split the student IDs in `train_data` into sequential batches of size `batch_size`."""
    students = train_data.keys()
    stud_ids = [stud_id for stud_id in students]
    batches = list(sampler.BatchSampler(sampler.SequentialSampler(stud_ids),
                                        batch_size=batch_size,
                                        drop_last=False))
    batch_ids = []
    for batch in batches:
        batch_ids.append([stud_ids[i] for i in batch])
    return batch_ids
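# A minimal, self-contained sketch of the chunking idea used in
# create_training_batch above: BatchSampler over a SequentialSampler yields
# lists of positions, which are then mapped back to the original IDs. The
# sample train_data dict here is hypothetical.
from torch.utils.data import sampler

train_data = {'s1': None, 's2': None, 's3': None, 's4': None, 's5': None}
stud_ids = list(train_data.keys())
batches = list(sampler.BatchSampler(sampler.SequentialSampler(stud_ids),
                                    batch_size=2, drop_last=False))
print([[stud_ids[i] for i in batch] for batch in batches])
# -> [['s1', 's2'], ['s3', 's4'], ['s5']]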
def __init__(self, input_tensor, input_lengths, labels_tensor, batch_size,
             sequence_length=2665):
    self.input_tensor = input_tensor
    self.input_lengths = input_lengths
    self.labels_tensor = labels_tensor
    self.batch_size = batch_size
    self.sequence_length = sequence_length
    self.sampler = splr.BatchSampler(splr.RandomSampler(self.labels_tensor),
                                     self.batch_size,
                                     False)
    self.sampler_iter = iter(self.sampler)
def test_weight_sampler():
    from torch.utils.data import sampler
    # Indices 0-9 get weight 3; the remaining 20 indices get weight 1.
    weight = [1] * 30
    weight[:10] = [3] * 10
    weight_sampler = sampler.WeightedRandomSampler(weight, num_samples=len(weight))
    batch_sampler = sampler.BatchSampler(weight_sampler, batch_size=4, drop_last=False)
    for indices in batch_sampler:
        print(indices)
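# A quick, hedged check of what the test above should show (counts are
# approximate, not exact): indices 0-9 carry weight 3 and the remaining 20
# indices carry weight 1, so P(index < 10) = 30/50 = 0.6 and roughly 60% of the
# sampled indices should fall below 10.
from collections import Counter
from torch.utils.data import sampler

weights = [3] * 10 + [1] * 20
weighted = sampler.WeightedRandomSampler(weights, num_samples=3000, replacement=True)
counts = Counter(i < 10 for batch in sampler.BatchSampler(weighted, 4, False) for i in batch)
print(counts)   # roughly Counter({True: 1800, False: 1200})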
def train_network(self, game_counter):
    if game_counter <= self.trajectory_size - 1:
        ub_index = game_counter * 30 - 1
    else:
        ub_index = self.trajectory_size * 30
    indices = list(sampler.BatchSampler(sampler.SubsetRandomSampler(range(ub_index)),
                                        batch_size=300,
                                        drop_last=False))
    i = 0
    for index in indices:
        if i >= 1:
            return
        self.optimizer.zero_grad()
        z = self.trajectory.outcome_values[index]
        prob = self.trajectory.mc_probs[index]
        v, net_prob = self.learning_network(self.trajectory.states[index])
        cross_entropy_loss = self.cross_entropy(net_prob, prob)
        loss = ((z - v).pow(2) + cross_entropy_loss).mean()
        loss.backward()
        self.optimizer.step()
        i += 1
        print(loss.item())
def get_image_dataloader(mode='train',
                         coco_set=2014,
                         images_path=os.environ['HOME'] + '/Database/coco/images/',
                         vocab_path='data/processed/coco_vocab.pkl',
                         captions_path='data/processed/coco_captions.csv',
                         batch_size=32,
                         max_len=30,
                         embedding_size=2048,
                         num_captions=5,
                         load_features=False,
                         load_captions=False,
                         preload=False,
                         model='resnet152',
                         num_workers=0):
    """
    Generate a dataloader with the specified parameters.

    Args:
        mode: Dataset type to load
        coco_set: COCO dataset year to load
        images_path: Path to COCO dataset images
        vocab_path: Path to COCO vocab file
        captions_path: Path to captions file
        batch_size: Batch size for the DataLoader
        max_len: Max caption length
        embedding_size: Size of image embedding
        num_captions: Number of captions per image in dataset
        load_features: Boolean for creating or loading image features
        load_captions: Boolean for creating or loading image captions
        preload: Boolean for preloading data into RAM during construction
        model: Base model for the encoder CNN
        num_workers: DataLoader parameter

    Return:
        data_loader: A torch DataLoader for the specified COCO dataset
    """
    # Ensure that the specified arguments are valid
    try:
        assert mode in ['train', 'val', 'test']
        assert coco_set in [2014, 2017]
        assert os.path.exists(images_path)
        assert os.path.exists(vocab_path)
        assert os.path.exists(captions_path)
    except AssertionError:
        # Defaulting conditions
        if mode not in ['train', 'val', 'test']:
            print('Invalid mode specified: '
                  '{}. Defaulting to val mode'.format(mode))
            mode = 'val'
        if coco_set not in [2014, 2017]:
            print('Invalid coco year specified: '
                  '{}. Defaulting to 2014'.format(coco_set))
            coco_set = 2014

        # Terminating conditions
        if not os.path.exists(images_path):
            print(images_path + " does not exist!")
            return None
        elif not os.path.exists(vocab_path):
            print(vocab_path + " does not exist!")
            return None
        elif not os.path.exists(captions_path):
            print(captions_path + " does not exist!")
            return None

    # Generate dataset
    data = ImageDataset(mode, coco_set, images_path, vocab_path, captions_path,
                        batch_size, max_len, embedding_size, num_captions,
                        load_features, load_captions, preload, model)

    # Create a dataloader -- only randomly sample when training
    if mode == 'train':
        # Get all possible image indices
        indices = data.get_indices()

        # Initialize a sampler for the indices
        init_sampler = sampler.SubsetRandomSampler(indices=indices)

        # Create data loader with dataset and sampler
        data_loader = DataLoader(dataset=data,
                                 num_workers=num_workers,
                                 batch_sampler=sampler.BatchSampler(
                                     sampler=init_sampler,
                                     batch_size=batch_size,
                                     drop_last=False))
    else:
        data_loader = DataLoader(dataset=data,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=num_workers)

    return data_loader
def __init__(self, storage, sampler=None, num_batches=10, batch_size=None,
             batch_size_bounds=None, replacement=True, verbose=0):
    """
    Initialize the storage sampler.

    Args:
        storage (Storage): storage to sample from.
        sampler (Sampler, None): if None, a sampler that randomly samples batches from the
            storage is used; by default it samples :attr:`num_batches` batches.
        num_batches (int): number of batches.
        batch_size (int, None): size of each batch. If None, it is computed from the size of
            the storage as batch_size = size(storage) // num_batches. Note that the batch size
            must be smaller than the size of the storage itself; num_batches * batch_size can,
            however, be bigger than the storage size if :attr:`replacement` is True.
        batch_size_bounds (tuple of int, None): if :attr:`batch_size` is None, lower and upper
            bounds for the computed `batch_size` can be given instead. For instance, setting
            it to `(16, 128)` along with `batch_size=None` computes
            `batch_size = size(storage) // num_batches`, but clips it to 16 if it is too small
            (<16) and to 128 if it is too big (>128).
        replacement (bool): if True, the same element may be sampled multiple times; if False,
            each element is sampled at most once.
        verbose (int, bool): verbose level, selected from {0, 1, 2}. If 0=False, nothing is
            printed. If 1=True, basic information about the sampler is printed. If verbose=2,
            detailed information is printed.
    """
    # set the storage
    self.storage = storage

    # set variables
    self._num_batches = num_batches
    self._replacement = bool(replacement)
    self._batch_size_bounds = batch_size_bounds
    self._batch_size_given = batch_size is not None
    self._verbose = verbose

    # set the sampler
    if sampler is None:
        # check batch size and compute it if necessary
        if batch_size is None:
            batch_size = self.size // num_batches

            # check batch size bounds
            if isinstance(batch_size_bounds, (tuple, list)) and len(batch_size_bounds) == 2:
                if batch_size < batch_size_bounds[0]:
                    batch_size = batch_size_bounds[0]
                elif batch_size > batch_size_bounds[1]:
                    batch_size = batch_size_bounds[1]

        # check the batch size * number of batches wrt the storage size
        if batch_size * num_batches > self.size and not self.replacement:
            raise ValueError(
                "Expecting the batch size (={}) * num_batches (={}) to be smaller than the "
                "size of the storage (={}), if we can not use replacement.".format(
                    batch_size, num_batches, self.size))

        # subsampler
        if replacement:
            subsampler = torch_sampler.RandomSampler(data_source=range(self.size),
                                                     replacement=replacement,
                                                     num_samples=self.size)
        else:
            subsampler = torch_sampler.SubsetRandomSampler(indices=range(self.size))

        # create sampler
        sampler = torch_sampler.BatchSampler(sampler=subsampler,
                                             batch_size=batch_size,
                                             drop_last=True)

    self.sampler = sampler

    if verbose:
        print("\nCreating sampler with size: {} - num batches: {} - batch size: {}"
              .format(self.size, num_batches, self.batch_size))
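# A tiny, hedged sketch of the batch-size rule described in the docstring above:
# batch_size = storage_size // num_batches, clipped to optional (low, high) bounds.
# The helper name is illustrative, not part of the sampler's API.
def compute_batch_size(storage_size, num_batches, batch_size_bounds=None):
    batch_size = storage_size // num_batches
    if isinstance(batch_size_bounds, (tuple, list)) and len(batch_size_bounds) == 2:
        low, high = batch_size_bounds
        batch_size = min(max(batch_size, low), high)
    return batch_size

print(compute_batch_size(1000, 10))              # 100
print(compute_batch_size(100, 10, (16, 128)))    # 10 -> clipped up to 16
print(compute_batch_size(10000, 10, (16, 128)))  # 1000 -> clipped down to 128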
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data_dir, args.dataset, 'train')
    valdir = os.path.join(args.data_dir, args.dataset, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # train_labeled_dataset = ImageNet(
    #     traindir, args,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]),
    #     db_path='./data_split/labeled_images_0.10.pth',
    # )
    #
    # train_unlabeled_dataset = ImageNet(
    #     traindir, args,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]),
    #     db_path='./data_split/unlabeled_images_0.90.pth',
    #     is_unlabeled=True,
    # )

    normalize = transforms.Normalize(mean=[0.482, 0.458, 0.408],
                                     std=[0.269, 0.261, 0.276])
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            # transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    labeled_idx = relabel_dataset(train_dataset, args.size)
    print("N_train_labeled:{}".format(labeled_idx.size))
    print("mean %f, max %f" % (labeled_idx.mean(), labeled_idx.max()))

    l_sampler = sampler.SubsetRandomSampler(labeled_idx)
    train_sampler = sampler.BatchSampler(l_sampler, args.batch_size, drop_last=False)
    train_labeled_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True)
    train_unlabeled_loader = None

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                # transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    entropy_criterion = HLoss()

    iter_sup = iter(train_labeled_loader)
    if train_unlabeled_loader is None:
        iter_unsup = None
    else:
        iter_unsup = iter(train_unlabeled_loader)

    model.train()
    meters = initialize_meters()
    for train_iter in trange(args.max_iter):
        lr = adjust_learning_rate(optimizer, train_iter + 1, args)
        try:
            batch_data = next(iter_sup)
        except StopIteration:
            iter_sup = iter(train_labeled_loader)
            batch_data = next(iter_sup)

        un_batch_data = None
        if iter_unsup is not None:
            try:
                un_batch_data = next(iter_unsup)
            except StopIteration:
                iter_unsup = iter(train_unlabeled_loader)
                un_batch_data = next(iter_unsup)

        train(batch_data, model, optimizer, criterion, un_batch_data,
              entropy_criterion, meters, args)

        if (train_iter + 1) % args.print_freq == 0:
            mes = 'ITER: [{0}/{1}] Data time {data_time.val:.3f} ({data_time.avg:.3f}) ' \
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' \
                  'Loss {loss.val:.4f} ({loss.avg:.4f}) HLoss {h_loss.val:.4f} ({h_loss.avg:.4f}) ' \
                  'Unsup Loss {unsup_loss.val:.4f} ({unsup_loss.avg:.4f}) ' \
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f}) Prec@5 {top5.val:.3f} ({top5.avg:.3f}) ' \
                  'Learning rate {2}' \
                  .format(train_iter + 1, args.max_iter, lr,
                          data_time=meters['data_time'], batch_time=meters['batch_time'],
                          loss=meters['losses'], h_loss=meters['losses_entropy'],
                          unsup_loss=meters['losses_unsup'],
                          top1=meters['top1'], top5=meters['top5'])
            tqdm.write(mes)

        if (train_iter + 1) % args.eval_iter == 0:
            # evaluate on validation set
            acc1, acc5 = validate(val_loader, model, criterion, args)
            if args.vis:
                dicts = {"Test/Acc": acc1, "Test/Top5Acc": acc5}
                vis_step(args.writer, (train_iter + 1) / args.eval_iter, dicts)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            save_checkpoint(
                {
                    'iter': train_iter + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.dir_path)

            model.train()
            meters = initialize_meters()
def get_video_dataloader(mode='train',
                         videos_path=os.environ['HOME'] + '/Database/MSR-VTT/train-video/',
                         vocab_path='data/processed/msrvtt_vocab.pkl',
                         captions_path='data/processed/msrvtt_captions.csv',
                         batch_size=32,
                         num_frames=40,
                         max_len=30,
                         embedding_size=2048,
                         num_captions=20,
                         load_features=False,
                         load_captions=False,
                         preload=False,
                         model='resnet152',
                         num_workers=0):
    """
    Generate a dataloader with the specified parameters.

    Args:
        mode: Dataset type to load
        videos_path: Path to the MSR-VTT videos
        vocab_path: Path to MSR-VTT vocab file
        captions_path: Path to captions file
        batch_size: Batch size for the DataLoader
        num_frames: Number of frames per video to process
        max_len: Max caption length
        embedding_size: Size of image embedding
        num_captions: Number of captions per video in dataset
        load_features: Boolean for creating or loading image features
        load_captions: Boolean for creating or loading image captions
        preload: Boolean for preloading data into RAM during construction
        model: Base model for the encoder CNN
        num_workers: DataLoader parameter

    Return:
        data_loader: A torch DataLoader for the MSR-VTT dataset
    """
    # Ensure that the specified mode is valid
    try:
        assert mode in ['train', 'dev', 'test']
    except AssertionError:
        print('Invalid mode specified: {}'.format(mode))
        print(' Defaulting to dev mode')
        mode = 'dev'

    # Build dataset
    data = VideoDataset(mode, videos_path, vocab_path, captions_path,
                        batch_size, num_frames, max_len, embedding_size,
                        num_captions, load_features, load_captions, preload,
                        model)

    if mode == 'train':
        # Get all possible video indices
        indices = data.get_indices()

        # Initialize a sampler for the indices
        init_sampler = sampler.SubsetRandomSampler(indices=indices)

        # Create data loader with dataset and sampler
        data_loader = DataLoader(dataset=data,
                                 num_workers=num_workers,
                                 batch_sampler=sampler.BatchSampler(
                                     sampler=init_sampler,
                                     batch_size=batch_size,
                                     drop_last=False))
    else:
        data_loader = DataLoader(dataset=data,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=num_workers)

    return data_loader
# Test this module
if __name__ == '__main__':
    N_TESTS = 4
    passed = 0
    dataset = Dataset(transform=unrel.TRANSFORM)

    # Test on a subset sampler
    batch_sampler = torchsampler.SequentialSampler(range(14))
    dataloader = FauxDataLoader(dataset, sampler=batch_sampler)
    for batch_i, batch in enumerate(dataloader):
        assert isinstance(batch['image'], list)
        for image in batch['image']:
            assert isinstance(image, torch.Tensor)
        print('dataset count %3d / %3d' % ((1 + batch_i) * dataloader.sampler.batch_size, len(dataloader)))
    passed += 1
    print('OK %d/%d' % (passed, N_TESTS))

    # Test on a batched subset sampler
    batch_sampler = torchsampler.BatchSampler(torchsampler.SequentialSampler(range(14)), 3, False)
    dataloader = FauxDataLoader(dataset, sampler=batch_sampler)
    for batch_i, batch in enumerate(dataloader):
        assert isinstance(batch['image'], list)
        for image in batch['image']:
            assert isinstance(image, torch.Tensor)
        print('dataset count %3d / %3d' % ((1 + batch_i) * dataloader.sampler.batch_size, len(dataloader)))
    passed += 1
    print('OK %d/%d' % (passed, N_TESTS))

    # Test a second time on the same dataloader to ensure that reset works
    for batch_i, batch in enumerate(dataloader):
        assert isinstance(batch['image'], list)
        for image in batch['image']:
            assert isinstance(image, torch.Tensor)
        print('dataset count %3d / %3d' % ((1 + batch_i) * dataloader.sampler.batch_size, len(dataloader)))
    passed += 1
    print('OK %d/%d' % (passed, N_TESTS))

    # Test without supplying sampler
def get_data_loader(
    transform: tv.transforms,
    caption_file: str,
    image_id_file: str,
    image_folder: str,
    config: Config,
    vocab_file: str,
    mode: str = "train",
    batch_size: int = 1,
    vocab_threshold=None,
    start_word: str = "<start>",
    end_word: str = "<end>",
    unk_word: str = "<unk>",
    vocab_from_file: bool = True,
    num_workers: int = 0,
):
    """Returns the data loader

    :param transform: [description]
    :type transform: tv.transforms
    :param mode: [description], defaults to "train"
    :type mode: str, optional
    :param batch_size: [description], defaults to 1
    :type batch_size: int, optional
    :param vocab_threshold: [description], defaults to None
    :type vocab_threshold: [type], optional
    :param vocab_file: [description], defaults to "output/vocab.pkl"
    :type vocab_file: str, optional
    :param start_word: [description], defaults to "<start>"
    :type start_word: str, optional
    :param end_word: [description], defaults to "<end>"
    :type end_word: str, optional
    :param unk_word: [description], defaults to "<unk>"
    :type unk_word: str, optional
    :param vocab_from_file: [description], defaults to True
    :type vocab_from_file: bool, optional
    :param num_workers: [description], defaults to 0
    :type num_workers: int, optional
    """
    assert mode in [
        "train",
        "validation",
        "test",
    ], f"mode: '{mode}' must be one of ['train', 'validation', 'test']"
    if not vocab_from_file:
        assert (
            mode == "train"
        ), f"mode: '{mode}', but to generate vocab from caption file, mode must be 'train'"

    if mode == "train":
        if vocab_from_file:
            assert os.path.exists(
                vocab_file
            ), "vocab_file does not exist. Change vocab_from_file to False to create vocab_file."
        assert (
            "train" in image_id_file
        ), f"double check image_id_file: {image_id_file}. File name should have the substring 'train'"
        assert os.path.exists(
            image_id_file
        ), f"image id file: {image_id_file} does not exist."
        assert os.path.exists(
            caption_file
        ), f"caption file: {caption_file} does not exist."
        assert os.path.isdir(
            config.IMAGE_DATA_DIR
        ), f"{config.IMAGE_DATA_DIR} is not a directory"
        assert (
            len(os.listdir(config.IMAGE_DATA_DIR)) != 0
        ), f"{config.IMAGE_DATA_DIR} is empty."

    if mode == "validation":
        assert (
            "dev" in image_id_file
        ), f"double check image_id_file: {image_id_file}. File name should have the substring 'dev'"
        assert os.path.exists(
            image_id_file
        ), f"image id file: {image_id_file} does not exist."
        assert os.path.exists(
            caption_file
        ), f"caption file: {caption_file} does not exist."
        assert os.path.isdir(
            config.IMAGE_DATA_DIR
        ), f"{config.IMAGE_DATA_DIR} is not a directory"
        assert (
            len(os.listdir(config.IMAGE_DATA_DIR)) != 0
        ), f"{config.IMAGE_DATA_DIR} is empty."
        assert os.path.exists(
            vocab_file
        ), f"Must first generate {vocab_file} from training data."
        assert vocab_from_file, "Change vocab_from_file to True."

    if mode == "test":
        assert (
            batch_size == 1
        ), "Please change batch_size to 1 if testing your model."
        assert (
            "test" in image_id_file
        ), f"double check image_id_file: {image_id_file}. File name should have the substring 'test'"
        assert os.path.exists(
            vocab_file
        ), f"Must first generate {vocab_file} from training data."
        assert vocab_from_file, "Change vocab_from_file to True."
    img_folder = config.IMAGE_DATA_DIR
    annotations_file = caption_file

    # image caption dataset
    dataset = FlickrDataset(
        transform,
        mode,
        batch_size,
        vocab_threshold,
        vocab_file,
        start_word,
        end_word,
        unk_word,
        caption_file,
        image_id_file,
        vocab_from_file,
        image_folder,
    )

    if mode in ["train", "validation"]:
        # Randomly sample a caption length, and sample indices with that length.
        indices = dataset.get_train_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices.
        initial_sampler = sampler.SubsetRandomSampler(indices=indices)
        # data loader for the Flickr caption dataset
        data_loader = DataLoader(
            dataset=dataset,
            num_workers=num_workers,
            batch_sampler=sampler.BatchSampler(
                sampler=initial_sampler,
                batch_size=dataset.batch_size,
                drop_last=False,
            ),
        )
    else:
        data_loader = DataLoader(
            dataset=dataset,
            batch_size=dataset.batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

    return data_loader
def main():
    """Main function"""
    args = parse_args()

    if DEBUG:
        args.dataset = 'soma'
        args.cfg_file = '../configs/soma_starting/e2e_mask_rcnn_soma_dsn_body.yaml'
        args.num_workers = 0
        args.batch_size = 2
        args.use_tfboard = False
        args.no_save = True
        args.set_cfgs = ['DEBUG', True]
        # args.resume = True
        # args.load_ckpt = '/ckpt/model_step8999.pth'

    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "soma":
        cfg.TRAIN.DATASETS = ('soma_det_seg_train',)
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "nuclei":
        cfg.TRAIN.DATASETS = ('nuclei_det_seg_train',)
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))

    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    cfg.NUM_GPUS = torch.cuda.device_count()
    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' %
          (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS: %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH: %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning rate linearly based on batch size change
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly proportional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = torch_sampler.BatchSampler(
        # sampler=MinibatchSampler(ratio_list, ratio_index),
        sampler=torch_sampler.RandomSampler(roidb),  # SequentialSampler(roidb)
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name + '.weight')
            gn_param_nameset.add(name + '.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) - set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # names of parameters for each parameter group
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    # Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        # net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        maskRCNN.load_state_dict(checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on PyTorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  # TODO: resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)

    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    if cfg.TRAIN.NEED_CROP:
                        roidb, ratio_list, ratio_index = combined_roidb_for_training(
                            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
                        batchSampler = torch_sampler.BatchSampler(
                            # sampler=MinibatchSampler(ratio_list, ratio_index),
                            sampler=torch_sampler.RandomSampler(roidb),
                            batch_size=args.batch_size,
                            drop_last=True
                        )
                        dataset = RoiDataLoader(
                            roidb,
                            cfg.MODEL.NUM_CLASSES,
                            training=True)
                        dataloader = torch.utils.data.DataLoader(
                            dataset,
                            batch_sampler=batchSampler,
                            num_workers=cfg.DATA_LOADER.NUM_THREADS,
                            collate_fn=collate_minibatch)
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb':
                        # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step + 1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()