from copy import deepcopy

import torch
from torch.utils.data import RandomSampler, Subset
from torch.utils.data.distributed import DistributedSampler

# Repo-local helpers assumed to be defined/imported elsewhere in this module:
# Regime, get_dataset, get_transform, DuplicateBatchSampler, and the
# _DATA_ARGS / _DATALOADER_ARGS / _TRANSFORM_ARGS / _OTHER_ARGS key lists.


class DataRegime(object):
    """Builds a DataLoader whose settings follow an epoch/step-based regime."""

    def __init__(self, regime, defaults={}):
        self.regime = Regime(regime, deepcopy(defaults))
        self.epoch = 0
        self.steps = None
        self.get_loader(True)

    def get_setting(self):
        # Split the flat regime setting into dataset / loader / transform /
        # other argument groups.
        setting = self.regime.setting
        loader_setting = {k: v for k, v in setting.items()
                          if k in _DATALOADER_ARGS}
        data_setting = {k: v for k, v in setting.items() if k in _DATA_ARGS}
        transform_setting = {k: v for k, v in setting.items()
                             if k in _TRANSFORM_ARGS}
        other_setting = {k: v for k, v in setting.items() if k in _OTHER_ARGS}
        transform_setting.setdefault('transform_name', data_setting['name'])
        return {'data': data_setting,
                'loader': loader_setting,
                'transform': transform_setting,
                'other': other_setting}

    def get(self, key, default=None):
        return self.regime.setting.get(key, default)

    def get_loader(self, force_update=False, override_settings=None,
                   subset_indices=None):
        # Rebuild the loader only when forced or when the regime reports a
        # settings change for the current epoch/steps.
        if force_update or self.regime.update(self.epoch, self.steps):
            setting = self.get_setting()
            if override_settings is not None:
                setting.update(override_settings)
            self._transform = get_transform(**setting['transform'])
            setting['data'].setdefault('transform', self._transform)
            self._data = get_dataset(**setting['data'])
            if subset_indices is not None:
                self._data = Subset(self._data, subset_indices)
            if setting['other'].get('distributed', False):
                setting['loader']['sampler'] = DistributedSampler(self._data)
                setting['loader']['shuffle'] = None
                # pin-memory currently broken for distributed
                setting['loader']['pin_memory'] = False
            self._sampler = setting['loader'].get('sampler', None)
            self._loader = torch.utils.data.DataLoader(
                self._data, **setting['loader'])
        return self._loader

    def set_epoch(self, epoch):
        self.epoch = epoch
        # Propagate the epoch to samplers that reshuffle per epoch
        # (e.g. DistributedSampler).
        if self._sampler is not None and hasattr(self._sampler, 'set_epoch'):
            self._sampler.set_epoch(epoch)

    def __len__(self):
        return len(self._data)
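
# Usage sketch for the override/subset hooks above (illustrative, not from the
# source): the regime format and the setting keys ('name', 'split',
# 'batch_size', ...) are assumptions that must match the Regime helper and the
# *_ARGS key lists. Note that override_settings is applied with dict.update,
# so it replaces a whole 'loader' / 'data' / 'transform' / 'other' entry
# rather than merging into it.
#
#   val_data = DataRegime([{'epoch': 0}],
#                         defaults={'name': 'cifar10', 'split': 'val',
#                                   'batch_size': 512, 'shuffle': False})
#   probe_loader = val_data.get_loader(
#       force_update=True,
#       override_settings={'loader': {'batch_size': 1024}},
#       subset_indices=list(range(1000)))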


# DataRegime variant that adds optional batch duplication through the
# repo-local DuplicateBatchSampler.
class DataRegime(object):
    def __init__(self, regime, defaults={}):
        self.regime = Regime(regime, defaults)
        self.epoch = 0
        self.steps = None
        self.get_loader(True)

    def get_setting(self):
        # Split the flat regime setting into dataset / loader / transform /
        # other argument groups.
        setting = self.regime.setting
        loader_setting = {k: v for k, v in setting.items()
                          if k in _DATALOADER_ARGS}
        data_setting = {k: v for k, v in setting.items() if k in _DATA_ARGS}
        transform_setting = {k: v for k, v in setting.items()
                             if k in _TRANSFORM_ARGS}
        other_setting = {k: v for k, v in setting.items() if k in _OTHER_ARGS}
        transform_setting.setdefault('transform_name', data_setting['name'])
        return {'data': data_setting,
                'loader': loader_setting,
                'transform': transform_setting,
                'other': other_setting}

    def get_loader(self, force_update=False):
        if force_update or self.regime.update(self.epoch, self.steps):
            setting = self.get_setting()
            self._transform = get_transform(**setting['transform'])
            setting['data'].setdefault('transform', self._transform)
            self._data = get_dataset(**setting['data'])
            if setting['other'].get('distributed', False):
                setting['loader']['sampler'] = DistributedSampler(self._data)
                setting['loader']['shuffle'] = None
                # pin-memory currently broken for distributed
                setting['loader']['pin_memory'] = False
            if setting['other'].get('duplicates', 0) > 1:
                setting['loader']['shuffle'] = None
                sampler = setting['loader'].get(
                    'sampler', RandomSampler(self._data))
                setting['loader']['sampler'] = DuplicateBatchSampler(
                    sampler, setting['loader']['batch_size'],
                    duplicates=setting['other']['duplicates'],
                    drop_last=setting['loader'].get('drop_last', False))
            self._sampler = setting['loader'].get('sampler', None)
            self._loader = torch.utils.data.DataLoader(
                self._data, **setting['loader'])
            if setting['other'].get('duplicates', 0) > 1:
                # Override the auto-created batch sampler so the duplicated
                # batches produced by DuplicateBatchSampler are actually used.
                self._loader.batch_sampler = self._sampler
        return self._loader

    def set_epoch(self, epoch):
        self.epoch = epoch
        if self._sampler is not None and hasattr(self._sampler, 'set_epoch'):
            self._sampler.set_epoch(epoch)

    def __len__(self):
        return len(self._data)
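

# Minimal end-to-end sketch (illustrative, not from the source). It assumes
# the epoch-keyed regime format and that 'name', 'split', 'input_size',
# 'batch_size', 'shuffle', and 'num_workers' appear in the corresponding
# *_ARGS lists handled by get_setting(); adjust to the actual Regime/helpers.
if __name__ == '__main__':
    train_regime = [
        {'epoch': 0, 'input_size': 128, 'batch_size': 256},
        {'epoch': 30, 'input_size': 224, 'batch_size': 128},
    ]
    train_data = DataRegime(train_regime,
                            defaults={'name': 'cifar10', 'split': 'train',
                                      'shuffle': True, 'num_workers': 4})
    for epoch in range(60):
        train_data.set_epoch(epoch)
        loader = train_data.get_loader()  # rebuilt only when the regime changes
        for inputs, targets in loader:
            pass  # training step would go here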