예제 #1
0
    def __init__(self, dataset: Union[Dataset, torch_data.Dataset], verbosity: int = 0, pool_size: int = 0,
                 slim: bool = False, rand_seed_log: str = None):
        """Given dataset and argument it with a calculate graph.

        Args:
            dataset: dataset to be augmented
            verbosity: Calculate graph's verbosity. 0: No log output; 1: Only give graph level log;
                >1: Give graph and node level log.
            pool_size: Calculate graph's pool size. 0 means don't use parallel.
            slim: If True, calculate graph use copy rather than deepcopy when setting value of Node's input.
            rand_seed_log: lmdb database where rand seeds saved
        """
        self._dataset = dataset

        self._rand_seed_log: Optional[str] = None
        self.lmdb_env: Optional[lmdb.Environment] = None
        self.rand_seed_log = rand_seed_log

        self.graph = Graph(verbosity=verbosity, pool_size=pool_size, slim=slim)
        self.build_graph()

        self.deep_prefix_id_ordered_nodes = self.graph.deep_prefix_id_ordered_nodes()

        multi_prefix_id_nodes = [(prefix_id_node) for prefix_id_node in self.deep_prefix_id_ordered_nodes
                                 if isinstance(prefix_id_node[1]._fct, MultiMap)]
        self.multi_nodes_prefix_id, self.multi_nodes = zip(*multi_prefix_id_nodes) if multi_prefix_id_nodes \
            else ([], [])

        rand_prefix_id_nodes = [(prefix_id_node) for prefix_id_node in self.deep_prefix_id_ordered_nodes
                                if isinstance(prefix_id_node[1]._fct, RandMap)]
        self.rand_nodes_prefix_id, self.rand_nodes = zip(*rand_prefix_id_nodes) if rand_prefix_id_nodes else \
            ([], [])

        self.multi_factors = self._get_multi_factors()
        self.divide_factors = self._get_divide_factors()
예제 #2
0
def build_pad_crop_mirror(crop_size, is_multi_mirror=False):
    """Assemble a slim Graph that pads, randomly crops, then mirrors an image/label pair.

    Args:
        crop_size: target size used both when padding and when cropping.
        is_multi_mirror: if True, use MultiMirror; otherwise use RandMirror.

    Returns:
        The constructed Graph.
    """
    graph = Graph(slim=True)

    # * Pad image and label up to the crop size.
    graph.add_node(pad_img_label,
                   inputs=['img', 'label', {'pad_img_to': crop_size}],
                   kwargs={'pad_location': 'center',
                           'ignore_label': VOC.ignore_label},
                   outputs=['padded_img', 'padded_label'])

    # * Rand crop
    graph.add_node(RandCrop(crop_size),
                   inputs=['padded_img', 'padded_label'],
                   outputs=['cropped_img', 'cropped_label'])

    # * Mirror: multi-output variant when requested, random single-output otherwise.
    mirror_fct = MultiMirror() if is_multi_mirror else RandMirror()
    graph.add_node(mirror_fct,
                   inputs=['cropped_img', 'cropped_label'],
                   outputs=['mirrored_img', 'mirrored_label'])

    return graph
예제 #3
0
def build_voc_raw_img_float_centralize_chw():
    """Build a slim Graph: int image -> float32 -> mean-subtracted (VOC BGR mean) -> CHW layout."""
    graph = Graph(slim=True)

    # Convert the raw integer image to float32.
    graph.add_node(int_img2float32_img,
                   inputs=['img'],
                   outputs=['float_img'])

    # Subtract the VOC per-channel BGR mean.
    graph.add_node(centralize,
                   inputs=['float_img', {'mean': VOC.mean_bgr}],
                   outputs=['centralized_img'])

    # Reorder from HWC to CHW.
    graph.add_node(HWC2CHW,
                   inputs=['centralized_img'],
                   outputs=['CHW_img'])

    return graph
예제 #4
0
def build_collect_and_attach():
    """Build a slim Graph that collects id/img/label into one example, then attaches cls info."""
    graph = Graph(slim=True)

    graph.add_node(collect_example,
                   args=['img_id', 'img', 'label'],
                   outputs=['collected'])
    graph.add_node(attach_cls,
                   args=['collected'],
                   outputs=['attached'])

    return graph
예제 #5
0
def build_jitter_scale(is_color_jitter, scale_factors):
    """Build a slim Graph: optional color jitter followed by random scaling.

    Args:
        is_color_jitter: when False the jitter node is built with jitter_prob=0.0,
            i.e. it is effectively disabled.
        scale_factors: factors passed to RandScale.

    Returns:
        The constructed Graph.
    """
    graph = Graph(slim=True)

    if is_color_jitter:
        jitter = RandColorJitter()
    else:
        jitter = RandColorJitter(jitter_prob=0.0)
    graph.add_node(jitter, inputs=['img'], outputs=['jitter_img'])

    # * Rand scale
    graph.add_node(RandScale(scale_factors),
                   inputs=['jitter_img', 'label'],
                   outputs=['scaled_img', 'scaled_label'])

    return graph
예제 #6
0
class DataAuger(Dataset):
    """Dataset wrapper that augments each example through a calculate Graph.

    MultiMap nodes multiply the dataset's length (one source example yields
    several augmented examples); RandMap nodes' rand seeds can be persisted to
    an lmdb database (``rand_seed_log``) so augmentation is reproducible.
    """

    def __init__(self, dataset: Union[Dataset, torch_data.Dataset], verbosity: int = 0, pool_size: int = 0,
                 slim: bool = False, rand_seed_log: Optional[str] = None):
        """Given a dataset, augment it with a calculate graph.

        Args:
            dataset: dataset to be augmented
            verbosity: Calculate graph's verbosity. 0: No log output; 1: Only give graph level log;
                >1: Give graph and node level log.
            pool_size: Calculate graph's pool size. 0 means don't use parallel.
            slim: If True, calculate graph use copy rather than deepcopy when setting value of Node's input.
            rand_seed_log: lmdb database where rand seeds saved
        """
        self._dataset = dataset

        # Assigning self.rand_seed_log (a property setter) validates the path and
        # (re)opens the lmdb environment, so the backing fields must exist first.
        self._rand_seed_log: Optional[str] = None
        self.lmdb_env: Optional[lmdb.Environment] = None
        self.rand_seed_log = rand_seed_log

        self.graph = Graph(verbosity=verbosity, pool_size=pool_size, slim=slim)
        self.build_graph()

        self.deep_prefix_id_ordered_nodes = self.graph.deep_prefix_id_ordered_nodes()

        # Collect (prefix_id, node) pairs whose wrapped function is a MultiMap/RandMap.
        # Empty fall-backs are tuples — not lists — so the attribute types always
        # match zip's tuple output regardless of whether any such node exists.
        multi_prefix_id_nodes = [prefix_id_node for prefix_id_node in self.deep_prefix_id_ordered_nodes
                                 if isinstance(prefix_id_node[1]._fct, MultiMap)]
        self.multi_nodes_prefix_id, self.multi_nodes = zip(*multi_prefix_id_nodes) if multi_prefix_id_nodes \
            else ((), ())

        rand_prefix_id_nodes = [prefix_id_node for prefix_id_node in self.deep_prefix_id_ordered_nodes
                                if isinstance(prefix_id_node[1]._fct, RandMap)]
        self.rand_nodes_prefix_id, self.rand_nodes = zip(*rand_prefix_id_nodes) if rand_prefix_id_nodes \
            else ((), ())

        self.multi_factors = self._get_multi_factors()
        self.divide_factors = self._get_divide_factors()

    def build_graph(self):
        """Build self.graph here. Need to be overloaded"""
        raise NotImplementedError

    def _get_multi_factors(self):
        """Return [len(dataset), len(multi_node_0._fct), len(multi_node_1._fct), ...]."""
        multi_factors = [len(self._dataset)]
        for node in self.multi_nodes:
            multi_factors.append(len(node._fct))
        return multi_factors

    def _get_divide_factors(self):
        """Return mixed-radix place values for decoding an auger index.

        divide_factors[i] is the product of multi_factors[i+1:] (ending with 1),
        so successive ``//`` and ``%`` peel off first the dataset index and then
        each multi node's output index.
        """
        divide_factors = self.multi_factors[::-1][:-1]
        divide_factors = list(accumulate(divide_factors, lambda x, y: x * y))
        return divide_factors[::-1] + [1]

    @property
    def dataset(self):
        """Get auger's dataset"""
        return self._dataset

    @dataset.setter
    def dataset(self, value):
        """Reset auger's dataset and refresh multi_factors.

        divide_factors only depends on the multi nodes' lengths (the dataset's
        length is dropped by _get_divide_factors), so it needs no rebuild.
        """
        self._dataset = value
        self.multi_factors = self._get_multi_factors()

    @property
    def multi_factor(self):
        """Number of augmented examples generated from one source example."""
        # Initializer 1 keeps this well-defined when there is no multi node
        # (reduce over an empty sequence would raise TypeError otherwise).
        return reduce(lambda x, y: x * y, self.multi_factors[1:], 1)

    def __len__(self):
        """Return the size of augmented dataset"""
        return reduce(lambda x, y: x * y, self.multi_factors, 1)

    def load_indices(self, idx):
        """Set every multi node's idx and return dataset's idx with auger's idx"""
        dataset_idx = idx // self.divide_factors[0]
        idx %= self.divide_factors[0]

        for multi_node, divide_factor in zip(self.multi_nodes, self.divide_factors[1:]):
            multi_node._fct.output_index = idx // divide_factor
            idx %= divide_factor

        return dataset_idx

    def calculate_indices(self, idx):
        """Return dataset's idx and every multi node's idx with auger's idx.

        Unlike load_indices, this has no side effect on the nodes.
        """
        dataset_idx = idx // self.divide_factors[0]
        idx %= self.divide_factors[0]

        node_indices = {}
        for prefix_id, divide_factor in zip(self.multi_nodes_prefix_id, self.divide_factors[1:]):
            node_indices[prefix_id] = idx // divide_factor
            idx %= divide_factor

        return dataset_idx, node_indices

    @property
    def rand_seeds(self):
        """Return {node_prefix_id: rand_seed} for every rand node."""
        return {prefix_id: node._fct.rand_seed
                for prefix_id, node in zip(self.rand_nodes_prefix_id, self.rand_nodes)}

    def load_rand_seeds(self, rand_seeds):
        """Load every rand node's seed from the {node_prefix_id: rand_seed} mapping."""
        for prefix_id, node in zip(self.rand_nodes_prefix_id, self.rand_nodes):
            node._fct.rand_seed = rand_seeds[prefix_id]

    @property
    def rand_seed_log(self):
        """Path of the lmdb rand-seed database, or None when logging is disabled."""
        return self._rand_seed_log

    @rand_seed_log.setter
    def rand_seed_log(self, rand_seed_log):
        """Set the rand-seed log path, closing any open env and opening the new one."""
        if (not isinstance(rand_seed_log, str)) and (rand_seed_log is not None):
            raise ValueError(f"rand_seed_log = {rand_seed_log} must be str or None")

        self._rand_seed_log = rand_seed_log

        if self.lmdb_env is not None:
            self.lmdb_env.close()
            self.lmdb_env = None

        if rand_seed_log is not None:
            self.lmdb_env = lmdb.open(rand_seed_log, meminit=False, map_size=2147483648, max_spare_txns=64,
                                      sync=False, metasync=False, lock=True)

    def __del__(self):
        # Release the lmdb environment. Guard against a partially constructed
        # instance (__init__ raised before lmdb_env/_rand_seed_log were set),
        # in which case the property setter would raise AttributeError.
        try:
            self.rand_seed_log = None
        except AttributeError:
            pass

    def get_item(self, idx):
        """Compute the augmented example at auger index ``idx``.

        When a rand seed log is configured, previously recorded seeds for this
        idx are restored before calculation; otherwise the seeds actually used
        are recorded afterwards so the result can be reproduced.
        """
        example = self._dataset[self.load_indices(idx)]

        rand_seeds = None
        if self.lmdb_env is not None:
            with self.lmdb_env.begin() as txn:
                rand_seeds = txn.get(str(idx).encode())
            if rand_seeds is not None:
                # NOTE(review): pickle.loads on the log's content — the rand
                # seed database must come from a trusted source.
                rand_seeds = pickle.loads(rand_seeds)
                self.load_rand_seeds(rand_seeds)

        ret = self.graph.calculate(data={'example': example})

        # First time this idx is computed with logging on: record the seeds used.
        if self.lmdb_env is not None and rand_seeds is None:
            with self.lmdb_env.begin(write=True) as txn:
                txn.put(str(idx).encode(), pickle.dumps(self.rand_seeds))

        return ret

    def __repr__(self):
        return f"DataAuger <{self.__class__.__name__}>:\n" \
                + indent(f"graph: {self.graph}") + "\n" \
                + indent(f"dataset: {self.dataset}") + "\n" \
                + indent(f"#DataAuger: {len(self)}")