import pickle
from functools import reduce
from itertools import accumulate
from typing import Optional, Union

import lmdb
import torch.utils.data as torch_data

# Project-internal names used below (Graph, Dataset, MultiMap, RandMap, RandCrop,
# RandMirror, MultiMirror, RandColorJitter, RandScale, RandScale, VOC, pad_img_label,
# int_img2float32_img, centralize, HWC2CHW, collect_example, attach_cls, indent)
# are assumed to be imported from this project's own modules.
def build_pad_crop_mirror(crop_size, is_multi_mirror=False):
    graph = Graph(slim=True)

    # * Pad img and label to crop size
    graph.add_node(pad_img_label,
                   inputs=['img', 'label', {'pad_img_to': crop_size}],
                   kwargs={'pad_location': 'center', 'ignore_label': VOC.ignore_label},
                   outputs=['padded_img', 'padded_label'])

    # * Rand crop
    rand_crop = RandCrop(crop_size)
    graph.add_node(rand_crop,
                   inputs=['padded_img', 'padded_label'],
                   outputs=['cropped_img', 'cropped_label'])

    # * Rand or multi mirror
    mirror = RandMirror() if not is_multi_mirror else MultiMirror()
    graph.add_node(mirror,
                   inputs=['cropped_img', 'cropped_label'],
                   outputs=['mirrored_img', 'mirrored_label'])

    return graph
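# A minimal usage sketch for the builder above (not part of the library). It
# assumes Graph.calculate takes a `data` dict keyed by input names and returns
# a dict of computed values, as inferred from DataAuger.get_item below, and
# that crop_size is a (height, width) pair.
def _example_pad_crop_mirror(img, label, crop_size=(321, 321)):
    graph = build_pad_crop_mirror(crop_size)
    ret = graph.calculate(data={'img': img, 'label': label})
    return ret['mirrored_img'], ret['mirrored_label']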
def build_voc_raw_img_float_centralize_chw():
    voc_raw_img_float_centralize_chw = Graph(slim=True)

    # * Convert int image to float32
    voc_raw_img_float_centralize_chw.add_node(int_img2float32_img,
                                              inputs=['img'],
                                              outputs=['float_img'])

    # * Subtract the VOC BGR mean
    voc_raw_img_float_centralize_chw.add_node(centralize,
                                              inputs=['float_img', {'mean': VOC.mean_bgr}],
                                              outputs=['centralized_img'])

    # * HWC -> CHW layout
    voc_raw_img_float_centralize_chw.add_node(HWC2CHW,
                                              inputs=['centralized_img'],
                                              outputs=['CHW_img'])

    return voc_raw_img_float_centralize_chw
def build_collect_and_attach():
    collect_and_attach = Graph(slim=True)

    # * Collect img_id, img and label into one example
    collect_and_attach.add_node(collect_example,
                                inputs=['img_id', 'img', 'label'],
                                outputs=['collected'])

    # * Attach cls to the collected example
    collect_and_attach.add_node(attach_cls,
                                inputs=['collected'],
                                outputs=['attached'])

    return collect_and_attach
def build_jitter_scale(is_color_jitter, scale_factors):
    graph = Graph(slim=True)

    # * Rand color jitter (disabled via jitter_prob=0.0 when is_color_jitter is False)
    rand_color_jitter = RandColorJitter() if is_color_jitter else RandColorJitter(jitter_prob=0.0)
    graph.add_node(rand_color_jitter, inputs=['img'], outputs=['jitter_img'])

    # * Rand scale
    rand_scale = RandScale(scale_factors)
    graph.add_node(rand_scale,
                   inputs=['jitter_img', 'label'],
                   outputs=['scaled_img', 'scaled_label'])

    return graph
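# A minimal sketch (not part of the library) chaining two builders by hand:
# the jitter/scale graph's outputs are fed into the pad/crop/mirror graph as
# its 'img'/'label' inputs. The Graph.calculate(data=...) API is assumed as
# above, and the scale_factors values are hypothetical.
def _example_jitter_scale_then_crop(img, label, crop_size=(321, 321)):
    jitter_graph = build_jitter_scale(is_color_jitter=True,
                                      scale_factors=(0.5, 0.75, 1.0, 1.25, 1.5))
    scaled = jitter_graph.calculate(data={'img': img, 'label': label})

    crop_graph = build_pad_crop_mirror(crop_size)
    cropped = crop_graph.calculate(data={'img': scaled['scaled_img'],
                                         'label': scaled['scaled_label']})
    return cropped['mirrored_img'], cropped['mirrored_label']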
class DataAuger(Dataset):

    def __init__(self, dataset: Union[Dataset, torch_data.Dataset],
                 verbosity: int = 0, pool_size: int = 0, slim: bool = False,
                 rand_seed_log: Optional[str] = None):
        """Given a dataset, augment it with a calculate graph.

        Args:
            dataset: Dataset to be augmented.
            verbosity: Calculate graph's verbosity. 0: no log output; 1: graph-level
                log only; >1: graph- and node-level log.
            pool_size: Calculate graph's pool size. 0 means no parallelism.
            slim: If True, the calculate graph uses copy rather than deepcopy when
                setting the value of a Node's input.
            rand_seed_log: lmdb database where rand seeds are saved.
        """
        self._dataset = dataset

        self._rand_seed_log: Optional[str] = None
        self.lmdb_env: Optional[lmdb.Environment] = None
        self.rand_seed_log = rand_seed_log

        self.graph = Graph(verbosity=verbosity, pool_size=pool_size, slim=slim)
        self.build_graph()

        self.deep_prefix_id_ordered_nodes = self.graph.deep_prefix_id_ordered_nodes()

        # Collect (prefix_id, node) pairs whose function is a MultiMap / RandMap.
        multi_prefix_id_nodes = [prefix_id_node
                                 for prefix_id_node in self.deep_prefix_id_ordered_nodes
                                 if isinstance(prefix_id_node[1]._fct, MultiMap)]
        self.multi_nodes_prefix_id, self.multi_nodes = zip(*multi_prefix_id_nodes) if multi_prefix_id_nodes \
            else ([], [])

        rand_prefix_id_nodes = [prefix_id_node
                                for prefix_id_node in self.deep_prefix_id_ordered_nodes
                                if isinstance(prefix_id_node[1]._fct, RandMap)]
        self.rand_nodes_prefix_id, self.rand_nodes = zip(*rand_prefix_id_nodes) if rand_prefix_id_nodes \
            else ([], [])

        self.multi_factors = self._get_multi_factors()
        self.divide_factors = self._get_divide_factors()

    def build_graph(self):
        """Build self.graph here. Must be overridden by subclasses."""
        raise NotImplementedError

    def _get_multi_factors(self):
        # multi_factors == [len(dataset), len(multi_node_0._fct), len(multi_node_1._fct), ...]
        multi_factors = [len(self._dataset)]
        for node in self.multi_nodes:
            multi_factors.append(len(node._fct))
        return multi_factors

    def _get_divide_factors(self):
        # Mixed-radix place values: divide_factors[i] is the product of all
        # multi_factors after position i, so idx // divide_factors[i] recovers
        # the i-th digit of an augmented index.
        divide_factors = self.multi_factors[::-1][:-1]
        divide_factors = list(accumulate(divide_factors, lambda x, y: x * y))
        return divide_factors[::-1] + [1]

    @property
    def dataset(self):
        """Get the auger's dataset."""
        return self._dataset

    @dataset.setter
    def dataset(self, value):
        """Reset the auger's dataset."""
        self._dataset = value
        self.multi_factors = self._get_multi_factors()

    @property
    def multi_factor(self):
        """Return the number of augmented examples produced from one dataset example."""
        return reduce(lambda x, y: x * y, self.multi_factors[1:])

    def __len__(self):
        """Return the size of the augmented dataset."""
        return reduce(lambda x, y: x * y, self.multi_factors)

    def load_indices(self, idx):
        """Set every multi node's output index and return the dataset index for the auger's idx."""
        dataset_idx = idx // self.divide_factors[0]
        idx %= self.divide_factors[0]

        for multi_node, divide_factor in zip(self.multi_nodes, self.divide_factors[1:]):
            multi_node._fct.output_index = idx // divide_factor
            idx %= divide_factor

        return dataset_idx

    def calculate_indices(self, idx):
        """Return the dataset index and every multi node's index for the auger's idx."""
        dataset_idx = idx // self.divide_factors[0]
        idx %= self.divide_factors[0]

        node_indices = {}
        for prefix_id, divide_factor in zip(self.multi_nodes_prefix_id, self.divide_factors[1:]):
            node_indices[prefix_id] = idx // divide_factor
            idx %= divide_factor

        return dataset_idx, node_indices

    @property
    def rand_seeds(self):
        """Return the rand seed of every rand node."""
        rand_seeds = {}
        for prefix_id, node in zip(self.rand_nodes_prefix_id, self.rand_nodes):
            rand_seeds[prefix_id] = node._fct.rand_seed
        return rand_seeds

    def load_rand_seeds(self, rand_seeds):
        """Load every rand node's seed from rand_seeds."""
        for prefix_id, node in zip(self.rand_nodes_prefix_id, self.rand_nodes):
            node._fct.rand_seed = rand_seeds[prefix_id]

    @property
    def rand_seed_log(self):
        return self._rand_seed_log

    @rand_seed_log.setter
    def rand_seed_log(self, rand_seed_log):
        if (not isinstance(rand_seed_log, str)) and (rand_seed_log is not None):
            raise ValueError(f"rand_seed_log = {rand_seed_log} must be str or None")

        self._rand_seed_log = rand_seed_log

        if self.lmdb_env is not None:
            self.lmdb_env.close()
            self.lmdb_env = None

        if rand_seed_log is not None:
            self.lmdb_env = lmdb.open(rand_seed_log, meminit=False, map_size=2147483648,
                                      max_spare_txns=64, sync=False, metasync=False, lock=True)

    def __del__(self):
        self.rand_seed_log = None

    def get_item(self, idx):
        example = self._dataset[self.load_indices(idx)]

        # Reuse logged rand seeds, if any, so the same idx is augmented identically.
        if self.lmdb_env is not None:
            with self.lmdb_env.begin() as txn:
                rand_seeds = txn.get(str(idx).encode())
                if rand_seeds is not None:
                    rand_seeds = pickle.loads(rand_seeds)
                    self.load_rand_seeds(rand_seeds)

        ret = self.graph.calculate(data={'example': example})

        # First access of this idx: log the rand seeds just used.
        if self.lmdb_env is not None and rand_seeds is None:
            with self.lmdb_env.begin(write=True) as txn:
                txn.put(str(idx).encode(), pickle.dumps(self.rand_seeds))

        return ret

    def __repr__(self):
        return f"DataAuger <{self.__class__.__name__}>:\n" \
            + indent(f"graph: {self.graph}") + "\n" \
            + indent(f"dataset: {self.dataset}") + "\n" \
            + indent(f"#DataAuger: {len(self)}")
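# A worked example of the mixed-radix index math in load_indices (the numbers
# are hypothetical). With len(dataset) == 100 and two MultiMap nodes of
# lengths 3 and 2:
#   multi_factors  == [100, 3, 2]
#   divide_factors == [6, 2, 1]    # place values: 3 * 2, 2, 1
#   len(auger)     == 600
# For idx == 457: dataset_idx == 457 // 6 == 76, remainder 1; the first multi
# node's output_index == 1 // 2 == 0; after idx %= 2, the second's == 1 // 1 == 1.

# A minimal subclass sketch (not part of the library): build_graph populates
# self.graph, whose 'example' input is fed by get_item. `_identity_augment` is
# a hypothetical stand-in for a real per-example augmentation function.
def _identity_augment(example):
    return example

class _ExampleAuger(DataAuger):
    def build_graph(self):
        self.graph.add_node(_identity_augment,
                            inputs=['example'], outputs=['augmented'])

# Usage (reproducibility behavior follows get_item above):
#   auger = _ExampleAuger(dataset, rand_seed_log='rand_seeds.lmdb')
#   out = auger.get_item(0)   # rand seeds logged to lmdb on first access
#   out = auger.get_item(0)   # logged seeds reloaded: identical augmentation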