def build_validation_data_loader(self):
    """Build the validation data loader and its evaluator, storing both on ``self``.

    When ``self.args.val_skip > 1``, the evaluation dataset is wrapped in
    ``SkipSubset(dataset, val_skip)`` before the loader is built. The wrapped
    dataset is written back to ``self.dataset_eval``. For that reason the wrap
    is skipped if ``self.dataset_eval`` is already a ``SkipSubset``; otherwise
    calling this method again would wrap the wrapper and compound the skip
    factor.

    Side effects:
        Sets ``self.dataset_eval`` (possibly wrapped), ``self.loader_eval`` and
        ``self.evaluator``.

    Returns:
        The validation loader produced by ``self._create_loader`` (the same
        object stored on ``self.loader_eval``).
    """
    args = self.args
    input_config = self.input_config

    # Guard against re-wrapping on repeated calls (see docstring).
    if args.val_skip > 1 and not isinstance(self.dataset_eval, SkipSubset):
        self.dataset_eval = SkipSubset(self.dataset_eval, args.val_skip)

    self.loader_eval = self._create_loader(
        self.dataset_eval,
        input_size=input_config['input_size'],
        # Batch size is taken from the trial context, not from args.
        batch_size=self.context.get_per_slot_batch_size(),
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=input_config['interpolation'],
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=self.labeler,
    )

    # The evaluator is built from the loader's dataset. Presumably this is the
    # (possibly subsampled) dataset passed in above -- confirm against
    # _create_loader.
    self.evaluator = create_evaluator(
        args.dataset, self.loader_eval.dataset, pred_yxyx=False, context=self.context)
    return self.loader_eval
def create_datasets_and_loaders(args, model_config):
    """Create the train/eval datasets, their data loaders, and an evaluator.

    NOTE(review): this module appears to define ``create_datasets_and_loaders``
    again further down (with transform / collate overrides). If so, that later
    definition replaces this one at import time.

    Args:
        args: Command line args / config for training.
        model_config: Model specific configuration dict / struct.

    Returns:
        Tuple ``(loader_train, loader_eval, evaluator)``.
    """
    input_config = resolve_input_config(args, model_config=model_config)
    dataset_train, dataset_eval = create_dataset(args.dataset, args.root)

    # The labeler is attached to the loaders only when the model bench is not
    # doing the anchor labeling itself.
    labeler = (
        None
        if args.bench_labeler
        else AnchorLabeler(
            Anchors.from_config(model_config), model_config.num_classes, match_threshold=0.5)
    )

    def _make_loader(dataset, is_training, interpolation, **extra):
        # Builds a loader with the settings common to train and eval;
        # ``extra`` carries the train-only random-erasing options.
        return create_loader(
            dataset,
            input_size=input_config['input_size'],
            batch_size=args.batch_size,
            is_training=is_training,
            use_prefetcher=args.prefetcher,
            interpolation=interpolation,
            fill_color=input_config['fill_color'],
            mean=input_config['mean'],
            std=input_config['std'],
            num_workers=args.workers,
            distributed=args.distributed,
            pin_mem=args.pin_mem,
            anchor_labeler=labeler,
            **extra,
        )

    # color_jitter / auto_augment (args.color_jitter, args.aa) are not forwarded.
    loader_train = _make_loader(
        dataset_train,
        True,
        # A train-specific interpolation, when set, overrides the resolved one.
        args.train_interpolation or input_config['interpolation'],
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )

    if args.val_skip > 1:
        dataset_eval = SkipSubset(dataset_eval, args.val_skip)
    loader_eval = _make_loader(dataset_eval, False, input_config['interpolation'])

    evaluator = create_evaluator(
        args.dataset, loader_eval.dataset, distributed=args.distributed, pred_yxyx=False)
    return loader_train, loader_eval, evaluator
def create_datasets_and_loaders(
        args,
        model_config,
        transform_train_fn=None,
        transform_eval_fn=None,
        collate_fn=None,
):
    """Build the train/eval datasets, their loaders, and the evaluator.

    Args:
        args: Command line args / config for training.
        model_config: Model specific configuration dict / struct.
        transform_train_fn: Optional override of the default image + annotation
            transforms for the training loader (see note in loaders.py).
        transform_eval_fn: Optional override of the default image + annotation
            transforms for the evaluation loader (see note in loaders.py).
        collate_fn: Optional override of the default fast collate function;
            passed to both the train and eval loaders.

    Returns:
        Tuple ``(loader_train, loader_eval, evaluator)``.
    """
    input_config = resolve_input_config(args, model_config=model_config)
    dataset_train, dataset_eval = create_dataset(args.dataset, args.root)

    # Anchor labeling is done in the loader / collate_fn unless it is enabled
    # in the model bench (args.bench_labeler).
    if args.bench_labeler:
        labeler = None
    else:
        anchors = Anchors.from_config(model_config)
        labeler = AnchorLabeler(anchors, model_config.num_classes, match_threshold=0.5)

    # Loader settings that are identical for training and evaluation.
    shared = dict(
        input_size=input_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        anchor_labeler=labeler,
        collate_fn=collate_fn,
    )

    # color_jitter / auto_augment (args.color_jitter, args.aa) are not forwarded.
    loader_train = create_loader(
        dataset_train,
        is_training=True,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        # A train-specific interpolation, when set, overrides the resolved one.
        interpolation=args.train_interpolation or input_config['interpolation'],
        transform_fn=transform_train_fn,
        **shared,
    )

    eval_source = SkipSubset(dataset_eval, args.val_skip) if args.val_skip > 1 else dataset_eval
    loader_eval = create_loader(
        eval_source,
        is_training=False,
        interpolation=input_config['interpolation'],
        transform_fn=transform_eval_fn,
        **shared,
    )

    evaluator = create_evaluator(
        args.dataset, loader_eval.dataset, distributed=args.distributed, pred_yxyx=False)
    return loader_train, loader_eval, evaluator