Code Example #1
def create_dataloader(config, data, mode):
    dataset = create_dataset(config, data, mode)
    if mode == 'train':
        # create Sampler
        if dist.is_available() and dist.is_initialized():
            train_RandomSampler = distributed.DistributedSampler(dataset)
        else:
            train_RandomSampler = sampler.RandomSampler(dataset, replacement=False)

        train_BatchSampler = sampler.BatchSampler(train_RandomSampler,
                                              batch_size=config.train.batch_size,
                                              drop_last=config.train.dataloader.drop_last)

        # Augment
        collator = get_collate_fn(config)

        # DataLoader
        data_loader = DataLoader(dataset=dataset,
                                batch_sampler=train_BatchSampler,
                                collate_fn=collator,
                                pin_memory=config.train.dataloader.pin_memory,
                                num_workers=config.train.dataloader.work_nums)

    elif mode == 'val':
        if dist.is_available() and dist.is_initialized():
            val_SequentialSampler = distributed.DistributedSampler(dataset)
        else:
            val_SequentialSampler = sampler.SequentialSampler(dataset)

        val_BatchSampler = sampler.BatchSampler(val_SequentialSampler,
                                                batch_size=config.val.batch_size,
                                                drop_last=config.val.dataloader.drop_last)
        data_loader = DataLoader(dataset,
                                batch_sampler=val_BatchSampler,
                                pin_memory=config.val.dataloader.pin_memory,
                                num_workers=config.val.dataloader.work_nums)
    else:
        if dist.is_available() and dist.is_initialized():
            test_SequentialSampler = distributed.DistributedSampler(dataset)
        else:
            test_SequentialSampler = None

        data_loader = DataLoader(dataset,
                                 sampler=test_SequentialSampler,
                                 batch_size=config.test.batch_size,
                                 pin_memory=config.test.dataloader.pin_memory,
                                 num_workers=config.test.dataloader.work_nums)
    return data_loader
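A minimal, self-contained sketch of the sampler wiring used above, run on a toy dataset (no distributed process group required); the names below are illustrative and not part of the original config. Note also that when the DistributedSampler branch is taken, the underlying DistributedSampler's set_epoch(epoch) should be called once per epoch so that shuffling differs across epochs.

import torch
from torch.utils.data import DataLoader, TensorDataset, sampler

toy = TensorDataset(torch.arange(10).float())
batch_sampler = sampler.BatchSampler(sampler.RandomSampler(toy),
                                     batch_size=4, drop_last=False)
loader = DataLoader(toy, batch_sampler=batch_sampler)
for (xb,) in loader:
    print(xb.shape)   # batches of 4, 4 and 2 samples, in random order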
Code Example #2
 def setup_data(self):
     transform = unrel.TRANSFORM
     # Initialize trainset
     self.trainset = data.Dataset(split='train',
                                  pairs='annotated',
                                  transform=transform)
     if self.opts.train_size:
         print('Using subset of %d from train_set' % self.opts.train_size)
         batch_sampler = sampler.SequentialSampler(
             range(self.opts.train_size))
     else:
         batch_sampler = None
     self.trainloader = data.FauxDataLoader(self.trainset,
                                            sampler=batch_sampler,
                                            batch_size=self.opts.batch_size)
     # Initialize testset
     if self.opts.do_validation:
         self.testset = data.Dataset(split='test',
                                     pairs='annotated',
                                     transform=transform)
         batch_sampler = sampler.BatchSampler(
             sampler.SequentialSampler(self.testset),
             self.opts.test_batch_size, False
         )  # make test set load without shuffling so that we can use Tyler's RecallEvaluator
         self.testloaders = [
             data.FauxDataLoader(self.testset, sampler=batch_sampler)
         ]
     else:
         print('No testset')
         self.testloaders = []
Code Example #3
 def __iter__(self):
     for bucket in self.bucket_sampler:
         sorted_sampler = SortedSampler(bucket, self.sort_key)
         for batch in samplers.SubsetRandomSampler(
                 list(
                     samplers.BatchSampler(sorted_sampler, self.batch_size,
                                           self.drop_last))):
             yield [bucket[i] for i in batch]
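The pattern above sorts the items inside each bucket (so a batch contains similar-length items) and then yields the batches in random order. A self-contained sketch of the same idea with plain torch samplers; LengthSortedSampler is a hypothetical stand-in for torchnlp's SortedSampler, not a torch class:

from torch.utils.data import sampler as samplers

class LengthSortedSampler(samplers.Sampler):
    """Stand-in for a SortedSampler: yields indices of `data` sorted by a key."""
    def __init__(self, data, key=len):
        self.order = sorted(range(len(data)), key=lambda i: key(data[i]))
    def __iter__(self):
        return iter(self.order)
    def __len__(self):
        return len(self.order)

bucket = ["a", "aaa", "aa", "aaaaa", "aaaa", "ab"]
batches = list(samplers.BatchSampler(LengthSortedSampler(bucket), 2, False))
for batch in samplers.SubsetRandomSampler(batches):   # batch order is randomized
    print([bucket[i] for i in batch])                 # similar-length items end up together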
Code Example #4
 def __init__(self, dataset, batch_size, shuffle = True, drop_last = False):
     # buckets: a list of groups, bucketed by context length
     self.buckets = bucket(dataset)  
     # shuffle the list of buckets
     if shuffle:
         np.random.shuffle(self.buckets)
         random_samplers = [sampler.RandomSampler(bucket) for bucket in self.buckets]
     else:
         random_samplers = [sampler.SequentialSampler(bucket) for bucket in self.buckets]
     self.sampler = [sampler.BatchSampler(s, batch_size, drop_last) for s in random_samplers]
Code Example #5
    def _init_sampler(self, **kwargs):
        sampler = kwargs.get('sampler')
        batch_size = kwargs.get('batch_size', 1)
        drop_last = kwargs.get('drop_last', False)
        if isinstance(sampler, torchsampler.BatchSampler):
            return sampler
        if sampler is None:
            sampler = torchsampler.RandomSampler(self.dataset)
        elif not isinstance(sampler, torchsampler.Sampler):
            sampler = torchsampler.RandomSampler(sampler)
        return torchsampler.BatchSampler(sampler, batch_size, drop_last)
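A standalone sketch of the same normalization logic, usable without the surrounding class; the function name and the toy dataset are illustrative. It accepts an existing BatchSampler, None, an arbitrary index collection, or any Sampler, and always returns a BatchSampler:

from torch.utils.data import sampler as torchsampler

def as_batch_sampler(dataset, sampler=None, batch_size=1, drop_last=False):
    # Already batched: pass it through unchanged.
    if isinstance(sampler, torchsampler.BatchSampler):
        return sampler
    # No sampler given: randomly sample the whole dataset.
    if sampler is None:
        sampler = torchsampler.RandomSampler(dataset)
    # Not a Sampler (e.g. a plain list of indices): sample randomly from it.
    elif not isinstance(sampler, torchsampler.Sampler):
        sampler = torchsampler.RandomSampler(sampler)
    return torchsampler.BatchSampler(sampler, batch_size, drop_last)

print(list(as_batch_sampler(range(6), batch_size=2)))   # e.g. [[4, 1], [0, 5], [3, 2]]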
Code Example #6
 def __init__(
     self,
     sampler,
     batch_size,
     drop_last,
     sort_key,
     bucket_size_multiplier=100,
 ):
     super().__init__(sampler, batch_size, drop_last)
     self.sort_key = sort_key
     self.bucket_sampler = samplers.BatchSampler(
         sampler, min(batch_size * bucket_size_multiplier, len(sampler)),
         False)
Code Example #7
def create_training_batch(train_data, batch_size):
    '''
    Split the student ids in train_data into sequential batches of
    batch_size and return them as a list of id lists.
    '''
    students = train_data.keys()
    stud_ids = [stud_id for stud_id in students]
    batches = list(
        sampler.BatchSampler(sampler.SequentialSampler(stud_ids),
                             batch_size=batch_size,
                             drop_last=False))
    batch_ids = []
    for batch in batches:
        batch_ids.append([stud_ids[i] for i in batch])
    return batch_ids
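An illustrative call, assuming the function above (and its torch sampler import) is in scope; the toy dict below is made up:

train_data = {'s1': [0.1], 's2': [0.2], 's3': [0.3], 's4': [0.4], 's5': [0.5]}
print(create_training_batch(train_data, batch_size=2))
# -> [['s1', 's2'], ['s3', 's4'], ['s5']]  (SequentialSampler keeps insertion order)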
Code Example #8
    def __init__(self,
                 input_tensor,
                 input_lengths,
                 labels_tensor,
                 batch_size,
                 sequence_length=2665):
        self.input_tensor = input_tensor
        self.input_lengths = input_lengths
        self.labels_tensor = labels_tensor
        self.batch_size = batch_size
        self.sequence_length = sequence_length

        self.sampler = splr.BatchSampler(
            splr.RandomSampler(self.labels_tensor), self.batch_size, False)
        self.sampler_iter = iter(self.sampler)
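A self-contained sketch of drawing a batch of indices the same way and using it to index tensors (the names here are illustrative):

import torch
from torch.utils.data import sampler as splr

labels = torch.arange(10)
batch_sampler = splr.BatchSampler(splr.RandomSampler(labels), 4, False)
sampler_iter = iter(batch_sampler)
idx = next(sampler_iter)     # e.g. [7, 2, 9, 0]
print(labels[idx])           # index a tensor with the sampled batch of indices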
Code Example #9
File: data.py  Project: SEU-DongHan/BDNet
def test_weight_sampler():
    from torch.utils.data import sampler
    weight = [1] * 30
    weight[:10] = [3] * 10
    weight_sampler = sampler.WeightedRandomSampler(weight,
                                                   num_samples=len(weight))
    batch_sampler = sampler.BatchSampler(weight_sampler,
                                         batch_size=4,
                                         drop_last=False)
    for indices in batch_sampler:
        print(indices)
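A rough way to check the effect of the weights (a sketch reusing the `weight` list and sampler import above, with a much larger num_samples so the intended 3:1 ratio becomes visible):

    from collections import Counter
    big_sampler = sampler.WeightedRandomSampler(weight, num_samples=30000)
    counts = Counter(int(i) for i in big_sampler)
    print(counts[0] / counts[29])   # roughly 3.0: index 0 has weight 3, index 29 has weight 1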
Code Example #10
File: Training.py  Project: hilfe123/bot
    def train_network(self, game_counter):
        if game_counter <= self.trajectory_size - 1:
            ub_index = game_counter * 30 - 1
        else:
            ub_index = self.trajectory_size * 30

        indices = list(sampler.BatchSampler(sampler.SubsetRandomSampler(range(ub_index)),
                                            batch_size=300, drop_last=False))
        i = 0

        for index in indices:
            if i >= 1:
                return
            self.optimizer.zero_grad()

            z = self.trajectory.outcome_values[index]
            prob = self.trajectory.mc_probs[index]
            v,net_prob = self.learning_network(self.trajectory.states[index])
            cross_entropy_loss = self.cross_entropy(net_prob, prob)

            loss = ((z - v).pow(2) + cross_entropy_loss).mean()
            loss.backward()
            self.optimizer.step()
            i += 1
            print(loss.item())
Code Example #11
def get_image_dataloader(mode='train',
                         coco_set=2014,
                         images_path=os.environ[
                             'HOME'] + '/Database/coco/images/',
                         vocab_path='data/processed/coco_vocab.pkl',
                         captions_path='data/processed/coco_captions.csv',
                         batch_size=32,
                         max_len=30,
                         embedding_size=2048,
                         num_captions=5,
                         load_features=False,
                         load_captions=False,
                         preload=False,
                         model='resnet152',
                         num_workers=0):
    """
    Generate a dataloader with the specified parameters.

    Args:
        mode: Dataset type to load
        coco_set: COCO dataset year to load
        images_path: Path to COCO dataset images
        vocab_path: Path to COCO vocab file
        captions_path: Path to captions file
        batch_size: Batch size for Dataloader
        max_len: Max caption length
        embedding_size: Size of image embedding
        num_captions: Number of captions per image in dataset
        load_features: Boolean for creating or loading image features
        load_captions: Boolean for creating or loading image captions
        preload: Boolean for preloading data into RAM during construction
        model: base model for encoderCNN
        num_workers: Dataloader parameter

    Return:
        data_loader: A torch dataloader for the specified coco dataset

    """
    # Ensure that specified mode is valid
    try:
        assert mode in ['train', 'val', 'test']
        assert coco_set in [2014, 2017]
        assert os.path.exists(images_path)
        assert os.path.exists(vocab_path)
        assert os.path.exists(captions_path)
    except AssertionError:
        # Defaulting conditions
        if mode not in ['train', 'val', 'test']:
            print('Invalid mode specified: ' +
                  '{}. Defaulting to val mode'.format(mode))
            mode = 'val'
        if coco_set not in [2014, 2017]:
            print('Invalid coco year specified: ' +
                  '{}. Defaulting to 2014'.format(coco_set))
            coco_set = 2014

        # Terminating conditions
        if not os.path.exists(images_path):
            print(images_path + " does not exist!")
            return None
        elif not os.path.exists(vocab_path):
            print(vocab_path + " does not exist!")
            return None
        elif not os.path.exists(captions_path):
            print(captions_path + " does not exist!")
            return None

    # Generate dataset
    data = ImageDataset(mode, coco_set, images_path, vocab_path,
                        captions_path, batch_size, max_len,
                        embedding_size, num_captions, load_features,
                        load_captions, preload, model)

    # Create a dataloader -- only randomly sample when
    # training
    if mode == 'train':
        # Get all possible image indices
        indices = data.get_indices()

        # Initialize a sampler for the indices
        init_sampler = sampler.SubsetRandomSampler(
            indices=indices)

        # Create data loader with dataset and sampler
        data_loader = DataLoader(dataset=data,
                                 num_workers=num_workers,
                                 batch_sampler=sampler.BatchSampler(
                                     sampler=init_sampler,
                                     batch_size=batch_size,
                                     drop_last=False))

    else:
        data_loader = DataLoader(dataset=data,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=num_workers)

    return data_loader
Code Example #12
File: sampler.py  Project: ultrainren/pyrobolearn
    def __init__(self,
                 storage,
                 sampler=None,
                 num_batches=10,
                 batch_size=None,
                 batch_size_bounds=None,
                 replacement=True,
                 verbose=0):
        """
        Initialize the storage sampler.

        Args:
            storage (Storage): storage to sample from.
            sampler (Sampler, None): if None, it will use a sampler that randomly samples batches from the storage.
                By default it will sample :attr:`num_batches` batches.
            num_batches (int): number of batches.
            batch_size (int, None): size of the batch. If None, it will be computed based on the size of the storage,
                where batch_size = size(storage) // num_batches. Note that the batch size must be smaller than the size
                of the storage itself. The num_batches * batch_size can however be bigger than the storage size if
                :attr:`replacement = True`.
            batch_size_bounds (tuple of int, None): if :attr:`batch_size` is None, we can instead specify the lower
                and upper bounds for the `batch_size`. For instance, we can set it to `(16, 128)` along with
                `batch_size=None`; this will result to compute `batch_size = size(storage) // num_batches` but if this
                one is too small (<16), it will be set to 16, and if this one is too big (>128), it will be set to 128.
            replacement (bool): if True, the same element can be sampled multiple times; if False, each element is
                sampled at most once.
            verbose (int, bool): verbose level, select between {0, 1, 2}. If 0=False, it won't print anything. If
                1=True, it will print basic information about the sampler. If verbose=2, it will print detailed
                information.
        """
        # set the storage
        self.storage = storage

        # set variables
        self._num_batches = num_batches
        self._replacement = bool(replacement)
        self._batch_size_bounds = batch_size_bounds
        self._batch_size_given = batch_size is not None
        self._verbose = verbose

        # set the sampler
        if sampler is None:

            # check batch size and compute it if necessary
            if batch_size is None:
                batch_size = self.size // num_batches

            # check batch size bounds
            if isinstance(batch_size_bounds,
                          (tuple, list)) and len(batch_size_bounds) == 2:
                if batch_size < batch_size_bounds[0]:
                    batch_size = batch_size_bounds[0]
                elif batch_size > batch_size_bounds[1]:
                    batch_size = batch_size_bounds[1]

            # check the batch size * number of batches wrt the storage size
            if batch_size * num_batches > self.size and not self.replacement:
                raise ValueError(
                    "Expecting the batch size (={}) * num_batches (={}) to be smaller than the size of "
                    "the storage (={}), if we can not use replacement.".format(
                        batch_size, num_batches, self.size))

            # subsampler
            if replacement:
                subsampler = torch_sampler.RandomSampler(
                    data_source=range(self.size),
                    replacement=replacement,
                    num_samples=self.size)
            else:
                subsampler = torch_sampler.SubsetRandomSampler(
                    indices=range(self.size))

            # create sampler
            sampler = torch_sampler.BatchSampler(sampler=subsampler,
                                                 batch_size=batch_size,
                                                 drop_last=True)

        self.sampler = sampler

        if verbose:
            print(
                "\nCreating sampler with size: {} - num batches: {} - batch size: {}"
                .format(self.size, num_batches, self.batch_size))
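The batch-size rule described in the docstring, extracted as a tiny standalone helper for illustration (pure Python; the names are not part of the original class):

def resolve_batch_size(storage_size, num_batches, bounds=None):
    # batch_size defaults to storage_size // num_batches, clamped to optional (lower, upper) bounds.
    batch_size = storage_size // num_batches
    if isinstance(bounds, (tuple, list)) and len(bounds) == 2:
        batch_size = max(bounds[0], min(batch_size, bounds[1]))
    return batch_size

print(resolve_batch_size(1000, 10, bounds=(16, 128)))   # 100
print(resolve_batch_size(1000, 100, bounds=(16, 128)))  # 16 (10 clamped up to the lower bound)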
Code Example #13
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data_dir, args.dataset, 'train')
    valdir = os.path.join(args.data_dir, args.dataset, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # train_labeled_dataset = ImageNet(
    #     traindir, args,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]),
    #     db_path='./data_split/labeled_images_0.10.pth',
    #     )
    #
    # train_unlabeled_dataset = ImageNet(
    #     traindir, args,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]),
    #     db_path='./data_split/unlabeled_images_0.90.pth',
    #     is_unlabeled=True,
    #     )

    normalize = transforms.Normalize(mean=[0.482, 0.458, 0.408],
                                     std=[0.269, 0.261, 0.276])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            # transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    labeled_idx = relabel_dataset(train_dataset, args.size)
    print("N_train_labeled:{}".format(labeled_idx.size))
    print("mean %f, max %f" % (labeled_idx.mean(), labeled_idx.max()))
    l_sampler = sampler.SubsetRandomSampler(labeled_idx)
    train_sampler = sampler.BatchSampler(l_sampler,
                                         args.batch_size,
                                         drop_last=False)

    train_labeled_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True)
    train_unlabeled_loader = None

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                # transforms.Resize(256), transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    entropy_criterion = HLoss()

    iter_sup = iter(train_labeled_loader)
    if train_unlabeled_loader is None:
        iter_unsup = None
    else:
        iter_unsup = iter(train_unlabeled_loader)

    model.train()
    meters = initialize_meters()
    for train_iter in trange(args.max_iter):

        lr = adjust_learning_rate(optimizer, train_iter + 1, args)
        try:
            batch_data = next(iter_sup)
        except StopIteration:
            iter_sup = iter(train_labeled_loader)
            batch_data = next(iter_sup)

        un_batch_data = None
        if iter_unsup is not None:
            try:
                un_batch_data = next(iter_unsup)
            except StopIteration:
                iter_unsup = iter(train_unlabeled_loader)
                un_batch_data = next(iter_unsup)

        train(batch_data, model, optimizer, criterion, un_batch_data,
              entropy_criterion, meters, args)

        if (train_iter + 1) % args.print_freq == 0:
            mes = 'ITER: [{0}/{1}]  Data time {data_time.val:.3f} ({data_time.avg:.3f}) Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' \
                  'Loss {loss.val:.4f} ({loss.avg:.4f})  HLoss {h_loss.val:.4f} ({h_loss.avg:.4f})  Unsup Loss {unsup_loss.val:.4f} ({unsup_loss.avg:.4f})' \
                  '  Prec@1 {top1.val:.3f} ({top1.avg:.3f})   Prec@5 {top5.val:.3f} ({top5.avg:.3f})   Learning rate {2}' \
                .format(train_iter+1, args.max_iter, lr, data_time=meters['data_time'], batch_time=meters['batch_time'],
                        loss=meters['losses'], h_loss=meters['losses_entropy'], unsup_loss=meters['losses_unsup'],
                        top1=meters['top1'],  top5=meters['top5'])
            tqdm.write(mes)
        if (train_iter + 1) % args.eval_iter == 0:
            # evaluate on validation set
            acc1, acc5 = validate(val_loader, model, criterion, args)
            if args.vis:
                dicts = {"Test/Acc": acc1, "Test/Top5Acc": acc5}
                vis_step(args.writer, (train_iter + 1) / args.eval_iter, dicts)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            save_checkpoint(
                {
                    'iter': train_iter + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.dir_path)
            model.train()
            meters = initialize_meters()
Code Example #14
def get_video_dataloader(mode='train',
                         videos_path=os.environ['HOME'] +
                         '/Database/MSR-VTT/train-video/',
                         vocab_path='data/processed/msrvtt_vocab.pkl',
                         captions_path='data/processed/msrvtt_captions.csv',
                         batch_size=32,
                         num_frames=40,
                         max_len=30,
                         embedding_size=2048,
                         num_captions=20,
                         load_features=False,
                         load_captions=False,
                         preload=False,
                         model='resnet152',
                         num_workers=0):
    """
    Generate a dataloader with the specified parameters.

    Args:
        mode: Dataset type to load
        videos_path: Path to MSR-VTT videos dataset
        vocab_path: Path to MSR-VTT vocab file
        captions_path: Path to captions file
        batch_size: Batch size for Dataloader
        num_frames: Number of frames per video to process
        max_len: Max caption length
        embedding_size: Size of image embedding
        num_captions: Number of captions per video in dataset
        load_features: Boolean for creating or loading image features
        load_captions: Boolean for creating or loading image captions
        preload: Boolean for preloading data into RAM during construction
        model: base model for encoderCNN
        num_workers: Dataloader parameter

    Return:
        data_loader: A torch dataloader for the MSR-VTT dataset

    """
    # Ensure that the specified mode is valid
    try:
        assert mode in ['train', 'dev', 'test']
    except AssertionError:
        print('Invalid mode specified: {}'.format(mode))
        print(' Defaulting to dev mode')
        mode = 'dev'

    # Build dataset
    data = VideoDataset(mode, videos_path, vocab_path, captions_path,
                        batch_size, num_frames, max_len, embedding_size,
                        num_captions, load_features, load_captions, preload,
                        model)

    if mode == 'train':
        # Get all possible video indices
        indices = data.get_indices()

        # Initialize a sampler for the indices
        init_sampler = sampler.SubsetRandomSampler(indices=indices)

        # Create data loader with dataset and sampler
        data_loader = DataLoader(dataset=data,
                                 num_workers=num_workers,
                                 batch_sampler=sampler.BatchSampler(
                                     sampler=init_sampler,
                                     batch_size=batch_size,
                                     drop_last=False))
    else:
        data_loader = DataLoader(dataset=data,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=num_workers)
    return data_loader
Code Example #15
# Test this module
if __name__ == '__main__':
	N_TESTS = 4
	passed = 0
	dataset = Dataset(transform=unrel.TRANSFORM)
	# Test on a subset sampler
	batch_sampler = torchsampler.SequentialSampler(range(14))
	dataloader = FauxDataLoader(dataset, sampler=batch_sampler)
	for batch_i, batch in enumerate(dataloader):
		assert isinstance(batch['image'], list)
		for image in batch['image']:
			assert isinstance(image, torch.Tensor)
		print('dataset count %3d / %3d' % ((1+batch_i) * dataloader.sampler.batch_size, len(dataloader)))
	passed += 1; print('OK %d/%d' % (passed, N_TESTS))
	# Test on a batched subset sampler
	batch_sampler = torchsampler.BatchSampler(torchsampler.SequentialSampler(range(14)), 3, False)
	dataloader = FauxDataLoader(dataset, sampler=batch_sampler)
	for batch_i, batch in enumerate(dataloader):
		assert isinstance(batch['image'], list)
		for image in batch['image']:
			assert isinstance(image, torch.Tensor)
		print('dataset count %3d / %3d' % ((1+batch_i) * dataloader.sampler.batch_size, len(dataloader)))
	passed += 1; print('OK %d/%d' % (passed, N_TESTS))
	# Test second time on same dataloader to ensure that reset works
	for batch_i, batch in enumerate(dataloader):
		assert isinstance(batch['image'], list)
		for image in batch['image']:
			assert isinstance(image, torch.Tensor)
		print('dataset count %3d / %3d' % ((1+batch_i) * dataloader.sampler.batch_size, len(dataloader)))
	passed += 1; print('OK %d/%d' % (passed, N_TESTS))
	# Test without supplying sampler
Code Example #16
def get_data_loader(
    transform: tv.transforms,
    caption_file: str,
    image_id_file: str,
    image_folder: str,
    config: Config,
    vocab_file: str,
    mode: str = "train",
    batch_size: int = 1,
    vocab_threshold=None,
    start_word: str = "<start>",
    end_word: str = "<end>",
    unk_word: str = "<unk>",
    vocab_from_file: bool = True,
    num_workers: int = 0,
):
    """Returns the data loader

    :param transform: [description]
    :type transform: tv.transforms
    :param mode: [description], defaults to "train"
    :type mode: str, optional
    :param batch_size: [description], defaults to 1
    :type batch_size: int, optional
    :param vocab_threshold: [description], defaults to None
    :type vocab_threshold: [type], optional
    :param vocab_file: [description], defaults to "output/vocab.pkl"
    :type vocab_file: str, optional
    :param start_word: [description], defaults to "<start>"
    :type start_word: str, optional
    :param end_word: [description], defaults to "<end>"
    :type end_word: str, optional
    :param unk_word: [description], defaults to "<unk>"
    :type unk_word: str, optional
    :param vocab_from_file: [description], defaults to True
    :type vocab_from_file: bool, optional
    :param num_workers: [description], defaults to 0
    :type num_workers: int, optional
    
    """

    assert mode in [
        "train",
        "validation",
        "test",
    ], f"mode: '{mode}' must be one of ['train','validation','test']"
    if not vocab_from_file:
        assert (
            mode == "train"
        ), f"mode: '{mode}', but to generate vocab from caption file, mode must be 'train' "

    if mode == "train":
        if vocab_from_file:
            assert os.path.exists(
                vocab_file
            ), "vocab_file does not exist.  Change vocab_from_file to False to create vocab_file."
        assert "train" in image_id_file, \
            f"double check image_id_file: {image_id_file}. File name should have the substring 'train'"
        assert os.path.exists(
            image_id_file
        ), f"image id file: {image_id_file} doesn't not exist."
        assert os.path.exists(
            caption_file
        ), f"caption file: {caption_file} doesn't not exist."
        assert os.path.isdir(
            config.IMAGE_DATA_DIR
        ), f"{config.IMAGE_DATA_DIR} not a directory"
        assert (
            len(os.listdir(config.IMAGE_DATA_DIR)) != 0
        ), f"{config.IMAGE_DATA_DIR} is empty."

    if mode == "validation":
        assert "dev" in image_id_file, \
            f"double check image_id_file: {image_id_file}. File name should have the substring 'dev'"
        assert os.path.exists(
            image_id_file
        ), f"image id file: {image_id_file} doesn't not exist."
        assert os.path.exists(
            caption_file
        ), f"caption file: {caption_file} doesn't not exist."
        assert os.path.isdir(
            config.IMAGE_DATA_DIR
        ), f"{config.IMAGE_DATA_DIR} not a directory"
        assert (
            len(os.listdir(config.IMAGE_DATA_DIR)) != 0
        ), f"{config.IMAGE_DATA_DIR} is empty."
        assert os.path.exists(
            vocab_file
        ), f"Must first generate {vocab_file} from training data."
        assert vocab_from_file, "Change vocab_from_file to True."

    if mode == "test":
        assert (
            batch_size == 1
        ), "Please change batch_size to 1 if testing your model."
        assert "test" in image_id_file, \
            f"double check image_id_file: {image_id_file}. File name should have the substring 'test'"
        assert os.path.exists(
            vocab_file
        ), f"Must first generate {vocab_file} from training data."
        assert vocab_from_file, "Change vocab_from_file to True."

    img_folder = config.IMAGE_DATA_DIR
    annotations_file = caption_file

    # image caption dataset
    dataset = FlickrDataset(
        transform,
        mode,
        batch_size,
        vocab_threshold,
        vocab_file,
        start_word,
        end_word,
        unk_word,
        caption_file,
        image_id_file,
        vocab_from_file,
        image_folder,
    )

    if mode in ["train", "validation"]:
        # Randomly sample a caption length, and sample indices with that length.
        indices = dataset.get_train_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices.
        initial_sampler = sampler.SubsetRandomSampler(indices=indices)
        # data loader for COCO dataset.
        data_loader = DataLoader(
            dataset=dataset,
            num_workers=num_workers,
            batch_sampler=sampler.BatchSampler(
                sampler=initial_sampler,
                batch_size=dataset.batch_size,
                drop_last=False,
            ),
        )
    else:
        data_loader = DataLoader(
            dataset=dataset,
            batch_size=dataset.batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

    return data_loader
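Note that in loaders built this way the index set of the BatchSampler is fixed when the loader is constructed. Training loops that follow this pattern therefore typically re-sample indices every step and swap them into the existing loader, roughly as sketched below (the loop and num_steps are illustrative; get_train_indices mirrors the dataset method used above):

for step in range(num_steps):                               # num_steps is hypothetical
    # draw a fresh set of indices (e.g. all captions of one sampled length)
    indices = data_loader.dataset.get_train_indices()
    data_loader.batch_sampler.sampler = sampler.SubsetRandomSampler(indices=indices)
    batch = next(iter(data_loader))                         # one batch built from those indices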
Code Example #17
def main():
    """Main function"""
    args = parse_args()
    if DEBUG:
        args.dataset = 'soma'
        args.cfg_file = '../configs/soma_starting/e2e_mask_rcnn_soma_dsn_body.yaml'
        args.num_workers = 0
        args.batch_size = 2
        args.use_tfboard = False
        args.no_save = True
        args.set_cfgs = ['DEBUG', True]
        # args.resume = True
        # args.load_ckpt = '/ckpt/model_step8999.pth'

    print('Called with args:')
    print(args)

    if not torch.cuda.is_available():
        sys.exit("Need a CUDA device to run the code.")

    if args.cuda or cfg.NUM_GPUS > 0:
        cfg.CUDA = True
    else:
        raise ValueError("Need Cuda device to run !")

    if args.dataset == "coco2017":
        cfg.TRAIN.DATASETS = ('coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 81
    elif args.dataset == "keypoints_coco2017":
        cfg.TRAIN.DATASETS = ('keypoints_coco_2017_train',)
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "soma":
        cfg.TRAIN.DATASETS = ('soma_det_seg_train',)
        cfg.MODEL.NUM_CLASSES = 2
    elif args.dataset == "nuclei":
        cfg.TRAIN.DATASETS = ('nuclei_det_seg_train',)
        cfg.MODEL.NUM_CLASSES = 2
    else:
        raise ValueError("Unexpected args.dataset: {}".format(args.dataset))


    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    ### Adaptively adjust some configs ###
    original_batch_size = cfg.NUM_GPUS * cfg.TRAIN.IMS_PER_BATCH
    original_ims_per_batch = cfg.TRAIN.IMS_PER_BATCH
    original_num_gpus = cfg.NUM_GPUS
    if args.batch_size is None:
        args.batch_size = original_batch_size
    
    cfg.NUM_GPUS = torch.cuda.device_count()

    assert (args.batch_size % cfg.NUM_GPUS) == 0, \
        'batch_size: %d, NUM_GPUS: %d' % (args.batch_size, cfg.NUM_GPUS)
    cfg.TRAIN.IMS_PER_BATCH = args.batch_size // cfg.NUM_GPUS
    effective_batch_size = args.iter_size * args.batch_size
    print('effective_batch_size = batch_size * iter_size = %d * %d' % (args.batch_size, args.iter_size))

    print('Adaptive config changes:')
    print('    effective_batch_size: %d --> %d' % (original_batch_size, effective_batch_size))
    print('    NUM_GPUS:             %d --> %d' % (original_num_gpus, cfg.NUM_GPUS))
    print('    IMS_PER_BATCH:        %d --> %d' % (original_ims_per_batch, cfg.TRAIN.IMS_PER_BATCH))

    ### Adjust learning based on batch size change linearly
    # For iter_size > 1, gradients are `accumulated`, so lr is scaled based
    # on batch_size instead of effective_batch_size
    old_base_lr = cfg.SOLVER.BASE_LR
    cfg.SOLVER.BASE_LR *= args.batch_size / original_batch_size
    print('Adjust BASE_LR linearly according to batch_size change:\n'
          '    BASE_LR: {} --> {}'.format(old_base_lr, cfg.SOLVER.BASE_LR))

    ### Adjust solver steps
    step_scale = original_batch_size / effective_batch_size
    old_solver_steps = cfg.SOLVER.STEPS
    old_max_iter = cfg.SOLVER.MAX_ITER
    cfg.SOLVER.STEPS = list(map(lambda x: int(x * step_scale + 0.5), cfg.SOLVER.STEPS))
    cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * step_scale + 0.5)
    print('Adjust SOLVER.STEPS and SOLVER.MAX_ITER linearly based on effective_batch_size change:\n'
          '    SOLVER.STEPS: {} --> {}\n'
          '    SOLVER.MAX_ITER: {} --> {}'.format(old_solver_steps, cfg.SOLVER.STEPS,
                                                  old_max_iter, cfg.SOLVER.MAX_ITER))

    # Scale FPN rpn_proposals collect size (post_nms_topN) in `collect` function
    # of `collect_and_distribute_fpn_rpn_proposals.py`
    #
    # post_nms_topN = int(cfg[cfg_key].RPN_POST_NMS_TOP_N * cfg.FPN.RPN_COLLECT_SCALE + 0.5)
    if cfg.FPN.FPN_ON and cfg.MODEL.FASTER_RCNN:
        cfg.FPN.RPN_COLLECT_SCALE = cfg.TRAIN.IMS_PER_BATCH / original_ims_per_batch
        print('Scale FPN rpn_proposals collect size directly proportional to the change of IMS_PER_BATCH:\n'
              '    cfg.FPN.RPN_COLLECT_SCALE: {}'.format(cfg.FPN.RPN_COLLECT_SCALE))

    if args.num_workers is not None:
        cfg.DATA_LOADER.NUM_THREADS = args.num_workers
    print('Number of data loading threads: %d' % cfg.DATA_LOADER.NUM_THREADS)

    ### Overwrite some solver settings from command line arguments
    if args.optimizer is not None:
        cfg.SOLVER.TYPE = args.optimizer
    if args.lr is not None:
        cfg.SOLVER.BASE_LR = args.lr
    if args.lr_decay_gamma is not None:
        cfg.SOLVER.GAMMA = args.lr_decay_gamma
    assert_and_infer_cfg()

    timers = defaultdict(Timer)

    ### Dataset ###
    timers['roidb'].tic()
    roidb, ratio_list, ratio_index = combined_roidb_for_training(
        cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)

    timers['roidb'].toc()
    roidb_size = len(roidb)
    logger.info('{:d} roidb entries'.format(roidb_size))
    logger.info('Takes %.2f sec(s) to construct roidb', timers['roidb'].average_time)

    # Effective training sample size for one epoch
    train_size = roidb_size // args.batch_size * args.batch_size

    batchSampler = torch_sampler.BatchSampler(
        #sampler=MinibatchSampler(ratio_list, ratio_index),
        sampler=torch_sampler.RandomSampler(roidb),#SequentialSampler(roidb),
        batch_size=args.batch_size,
        drop_last=True
    )
    dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        training=True)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

    ### Model ###
    maskRCNN = Generalized_RCNN()

    if cfg.CUDA:
        maskRCNN.cuda()

    ### Optimizer ###
    gn_param_nameset = set()
    for name, module in maskRCNN.named_modules():
        if isinstance(module, nn.GroupNorm):
            gn_param_nameset.add(name+'.weight')
            gn_param_nameset.add(name+'.bias')
    gn_params = []
    gn_param_names = []
    bias_params = []
    bias_param_names = []
    nonbias_params = []
    nonbias_param_names = []
    nograd_param_names = []
    for key, value in maskRCNN.named_parameters():
        if value.requires_grad:
            if 'bias' in key:
                bias_params.append(value)
                bias_param_names.append(key)
            elif key in gn_param_nameset:
                gn_params.append(value)
                gn_param_names.append(key)
            else:
                nonbias_params.append(value)
                nonbias_param_names.append(key)
        else:
            nograd_param_names.append(key)
    assert (gn_param_nameset - set(nograd_param_names) - set(bias_param_names)) == set(gn_param_names)

    # Learning rate of 0 is a dummy value to be set properly at the start of training
    params = [
        {'params': nonbias_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY},
        {'params': bias_params,
         'lr': 0 * (cfg.SOLVER.BIAS_DOUBLE_LR + 1),
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY if cfg.SOLVER.BIAS_WEIGHT_DECAY else 0},
        {'params': gn_params,
         'lr': 0,
         'weight_decay': cfg.SOLVER.WEIGHT_DECAY_GN}
    ]
    # names of the parameters in each parameter group
    param_names = [nonbias_param_names, bias_param_names, gn_param_names]

    if cfg.SOLVER.TYPE == "SGD":
        optimizer = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.TYPE == "Adam":
        optimizer = torch.optim.Adam(params)

    # Load checkpoint
    if args.load_ckpt:
        load_name = args.load_ckpt
        logging.info("loading checkpoint %s", load_name)
        checkpoint = torch.load(load_name, map_location=lambda storage, loc: storage)
        #net_utils.load_ckpt(maskRCNN, checkpoint['model'])
        maskRCNN.load_state_dict(checkpoint['model'])
        if args.resume:
            args.start_step = checkpoint['step'] + 1
            if 'train_size' in checkpoint:  # For backward compatibility
                if checkpoint['train_size'] != train_size:
                    print('train_size value: %d different from the one in checkpoint: %d'
                          % (train_size, checkpoint['train_size']))

            # reorder the params in optimizer checkpoint's params_groups if needed
            # misc_utils.ensure_optimizer_ckpt_params_order(param_names, checkpoint)

            # There is a bug in optimizer.load_state_dict on Pytorch 0.3.1.
            # However it's fixed on master.
            optimizer.load_state_dict(checkpoint['optimizer'])
            # misc_utils.load_optimizer_state_dict(optimizer, checkpoint['optimizer'])
        del checkpoint
        torch.cuda.empty_cache()

    if args.load_detectron:  #TODO resume for detectron weights (load sgd momentum values)
        logging.info("loading Detectron weights %s", args.load_detectron)
        load_detectron_weight(maskRCNN, args.load_detectron)

    lr = optimizer.param_groups[0]['lr']  # lr of non-bias parameters, for command line outputs.

    maskRCNN = mynn.DataParallel(maskRCNN, cpu_keywords=['im_info', 'roidb'],
                                 minibatch=True)

    ### Training Setups ###
    args.run_name = misc_utils.get_run_name() + '_step'
    output_dir = misc_utils.get_output_dir(args, args.run_name)
    args.cfg_filename = os.path.basename(args.cfg_file)

    if not args.no_save:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        blob = {'cfg': yaml.dump(cfg), 'args': args}
        with open(os.path.join(output_dir, 'config_and_args.pkl'), 'wb') as f:
            pickle.dump(blob, f, pickle.HIGHEST_PROTOCOL)

        if args.use_tfboard:
            from tensorboardX import SummaryWriter
            # Set the Tensorboard logger
            tblogger = SummaryWriter(output_dir)

    ### Training Loop ###
    maskRCNN.train()

    CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)

    # Set index for decay steps
    decay_steps_ind = None
    for i in range(1, len(cfg.SOLVER.STEPS)):
        if cfg.SOLVER.STEPS[i] >= args.start_step:
            decay_steps_ind = i
            break
    if decay_steps_ind is None:
        decay_steps_ind = len(cfg.SOLVER.STEPS)

    training_stats = TrainingStats(
        args,
        args.disp_interval,
        tblogger if args.use_tfboard and not args.no_save else None)
    try:
        logger.info('Training starts !')
        step = args.start_step
        for step in range(args.start_step, cfg.SOLVER.MAX_ITER):

            # Warm up
            if step < cfg.SOLVER.WARM_UP_ITERS:
                method = cfg.SOLVER.WARM_UP_METHOD
                if method == 'constant':
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR
                elif method == 'linear':
                    alpha = step / cfg.SOLVER.WARM_UP_ITERS
                    warmup_factor = cfg.SOLVER.WARM_UP_FACTOR * (1 - alpha) + alpha
                else:
                    raise KeyError('Unknown SOLVER.WARM_UP_METHOD: {}'.format(method))
                lr_new = cfg.SOLVER.BASE_LR * warmup_factor
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
            elif step == cfg.SOLVER.WARM_UP_ITERS:
                net_utils.update_learning_rate(optimizer, lr, cfg.SOLVER.BASE_LR)
                lr = optimizer.param_groups[0]['lr']
                assert lr == cfg.SOLVER.BASE_LR

            # Learning rate decay
            if decay_steps_ind < len(cfg.SOLVER.STEPS) and \
                    step == cfg.SOLVER.STEPS[decay_steps_ind]:
                logger.info('Decay the learning rate on step %d', step)
                lr_new = lr * cfg.SOLVER.GAMMA
                net_utils.update_learning_rate(optimizer, lr, lr_new)
                lr = optimizer.param_groups[0]['lr']
                assert lr == lr_new
                decay_steps_ind += 1

            training_stats.IterTic()
            optimizer.zero_grad()
            for inner_iter in range(args.iter_size):
                try:
                    input_data = next(dataiterator)
                except StopIteration:
                    if cfg.TRAIN.NEED_CROP:
                        roidb, ratio_list, ratio_index = combined_roidb_for_training(
                            cfg.TRAIN.DATASETS, cfg.TRAIN.PROPOSAL_FILES)
                        batchSampler = torch_sampler.BatchSampler(
                            # sampler=MinibatchSampler(ratio_list, ratio_index),
                            sampler=torch_sampler.RandomSampler(roidb),
                            batch_size=args.batch_size,
                            drop_last=True
                        )
                        dataset = RoiDataLoader(
                            roidb,
                            cfg.MODEL.NUM_CLASSES,
                            training=True)

                        dataloader = torch.utils.data.DataLoader(
                            dataset,
                            batch_sampler=batchSampler,
                            num_workers=cfg.DATA_LOADER.NUM_THREADS,
                            collate_fn=collate_minibatch)
                    dataiterator = iter(dataloader)
                    input_data = next(dataiterator)

                for key in input_data:
                    if key != 'roidb': # roidb is a list of ndarrays with inconsistent length
                        input_data[key] = list(map(Variable, input_data[key]))

                net_outputs = maskRCNN(**input_data)
                training_stats.UpdateIterStats(net_outputs, inner_iter)
                loss = net_outputs['total_loss']
                loss.backward()
            optimizer.step()
            training_stats.IterToc()

            training_stats.LogIterStats(step, lr)

            if (step+1) % CHECKPOINT_PERIOD == 0:
                save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

        # ---- Training ends ----
        # Save last checkpoint
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)

    except (RuntimeError, KeyboardInterrupt):
        del dataiterator
        logger.info('Save ckpt on exception ...')
        save_ckpt(output_dir, args, step, train_size, maskRCNN, optimizer)
        logger.info('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    finally:
        if args.use_tfboard and not args.no_save:
            tblogger.close()