def __call__(self, cls):
    """Validate ``cls`` and register it on ``IterDataPipe`` under ``self.name``.

    Args:
        cls: The decorated object. Accepted when it is a class that
            subclasses ``IterDataPipe``, or a non-class object that is a
            ``non_deterministic`` instance or whose ``__self__`` attribute
            is a ``non_deterministic`` instance.

    Returns:
        ``cls`` itself, unchanged.

    Raises:
        TypeError: If ``cls`` satisfies neither of the conditions above.
    """
    if isinstance(cls, Type):  # type: ignore
        acceptable = issubclass(cls, IterDataPipe)
    else:
        # with non_deterministic decorator
        acceptable = isinstance(cls, non_deterministic) or (
            hasattr(cls, '__self__')
            and isinstance(cls.__self__, non_deterministic)
        )

    if not acceptable:
        raise TypeError(
            '`functional_datapipe` can only decorate IterDataPipe')

    IterDataPipe.register_datapipe_as_function(self.name, cls)
    return cls
def __getstate__(self):
    """Build the state used when this object is pickled.

    When ``IterDataPipe.getstate_hook`` is set, its result for ``self`` is
    returned instead. Otherwise the state is the tuple
    ``(datapipe, filter_fn, args, kwargs, drop_empty_batches)``, where
    ``filter_fn`` is replaced by ``dill.dumps(filter_fn)`` if
    ``DILL_AVAILABLE`` is truthy.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)

    # Store the dill payload when dill is available; otherwise keep the
    # callable as is.
    fn_payload = dill.dumps(self.filter_fn) if DILL_AVAILABLE else self.filter_fn
    return (
        self.datapipe,
        fn_payload,
        self.args,
        self.kwargs,
        self.drop_empty_batches,
    )
def __getstate__(self):
    """Build the state used when this object is pickled.

    Returns ``IterDataPipe.getstate_hook(self)`` when that hook is set;
    otherwise the tuple ``(main_datapipe, num_instances, buffer_size)``.
    """
    hook = IterDataPipe.getstate_hook
    if hook is None:
        return (self.main_datapipe, self.num_instances, self.buffer_size)
    return hook(self)
def __call__(self, cls):
    """Register ``cls`` as a functional datapipe under ``self.name``.

    Classes are dispatched by base class: ``IterDataPipe`` subclasses are
    registered on ``IterDataPipe`` (forwarding ``enable_df_api_tracing``),
    ``MapDataPipe`` subclasses on ``MapDataPipe``. Any other class is
    returned without being registered.

    Non-class objects are accepted only when they are a
    ``non_deterministic`` instance or expose a ``__self__`` that is one;
    they are registered on ``IterDataPipe``.

    Args:
        cls: The decorated class or ``non_deterministic``-wrapped object.

    Returns:
        ``cls`` itself, unchanged.

    Raises:
        TypeError: If ``cls`` is an ``IterDataPipe`` subclass whose
            metaclass is not ``_DataPipeMeta``, or a non-class object not
            produced by ``non_deterministic``.
    """
    error_message = '`functional_datapipe` can only decorate IterDataPipe'

    # Check "is it a class?" first: issubclass() raises TypeError for
    # non-class arguments, which previously made the non_deterministic
    # branch below unreachable.
    if isinstance(cls, type):
        if issubclass(cls, IterDataPipe):
            if not isinstance(cls, _DataPipeMeta):
                raise TypeError(error_message)
            IterDataPipe.register_datapipe_as_function(
                self.name, cls,
                enable_df_api_tracing=self.enable_df_api_tracing)
        elif issubclass(cls, MapDataPipe):
            MapDataPipe.register_datapipe_as_function(self.name, cls)
        return cls

    # with non_deterministic decorator
    from_non_deterministic = isinstance(cls, non_deterministic) or (
        hasattr(cls, '__self__')
        and isinstance(cls.__self__, non_deterministic)
    )
    if not from_non_deterministic:
        raise TypeError(error_message)

    IterDataPipe.register_datapipe_as_function(
        self.name, cls,
        enable_df_api_tracing=self.enable_df_api_tracing)
    return cls
def __getstate__(self):
    """Build the state used when this object is pickled.

    Returns ``IterDataPipe.getstate_hook(self)`` when that hook is set.
    Otherwise returns the tuple ``(main_datapipe, num_instances,
    buffer_size, serialize_fn(classifier_fn), drop_none)``.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)

    classifier_payload = serialize_fn(self.classifier_fn)
    return (
        self.main_datapipe,
        self.num_instances,
        self.buffer_size,
        classifier_payload,
        self.drop_none,
    )
def __getstate__(self):
    """Build the state used when this object is pickled.

    Returns ``IterDataPipe.getstate_hook(self)`` when that hook is set.
    Otherwise returns ``(datapipe, fn, input_col, output_col)``, with
    ``fn`` replaced by ``dill.dumps(fn)`` if ``DILL_AVAILABLE`` is truthy.
    """
    if IterDataPipe.getstate_hook is None:
        # No hook installed: assemble the tuple ourselves.
        fn_payload = dill.dumps(self.fn) if DILL_AVAILABLE else self.fn
        return (self.datapipe, fn_payload, self.input_col, self.output_col)
    return IterDataPipe.getstate_hook(self)
def __getstate__(self):
    """Build the state used when this object is pickled.

    Returns ``IterDataPipe.getstate_hook(self)`` when that hook is set.
    Otherwise returns ``(main_datapipe, num_instances, buffer_size,
    classifier_fn, drop_none)``, with ``classifier_fn`` replaced by
    ``dill.dumps(classifier_fn)`` if ``DILL_AVAILABLE`` is truthy.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)

    if DILL_AVAILABLE:
        classifier_payload = dill.dumps(self.classifier_fn)
    else:
        classifier_payload = self.classifier_fn

    return (
        self.main_datapipe,
        self.num_instances,
        self.buffer_size,
        classifier_payload,
        self.drop_none,
    )
def __init__(self,
             datapipe: IterDataPipe,
             batch_size: int,
             drop_last: bool = False,
             unbatch_level: int = 0,
             ) -> None:
    """Store the batching configuration on the instance.

    Args:
        datapipe: Source datapipe. When ``unbatch_level`` is non-zero,
            ``datapipe.unbatch(unbatch_level=unbatch_level)`` is stored in
            its place.
        batch_size: Must be greater than 0 (checked with ``assert``).
        drop_last: Stored as ``self.drop_last``.
        unbatch_level: Stored as ``self.unbatch_level``; ``0`` means the
            source datapipe is used as given.
    """
    assert batch_size > 0, "Batch size is required to be larger than 0!"
    super().__init__()

    # Only wrap the source with unbatch() when a non-zero level was requested.
    self.datapipe = (
        datapipe
        if unbatch_level == 0
        else datapipe.unbatch(unbatch_level=unbatch_level)
    )
    self.unbatch_level = unbatch_level
    self.batch_size = batch_size
    self.drop_last = drop_last
    self.length = None
    self.wrapper_class = DataChunk
def __getstate__(self):
    """Build the state used when this object is pickled.

    Returns ``IterDataPipe.getstate_hook(self)`` when that hook is set.
    Otherwise returns ``(datapipe, group_key_fn, buffer_size, group_size,
    guaranteed_group_size, drop_remaining)``, with ``group_key_fn``
    replaced by ``dill.dumps(group_key_fn)`` if ``DILL_AVAILABLE`` is
    truthy.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)

    key_fn_payload = (
        dill.dumps(self.group_key_fn) if DILL_AVAILABLE else self.group_key_fn
    )
    return (
        self.datapipe,
        key_fn_payload,
        self.buffer_size,
        self.group_size,
        self.guaranteed_group_size,
        self.drop_remaining,
    )
def list_connected_datapipes(scan_obj, exclude_primitive):
    """Collect the objects handed to the reduce hook while pickling ``scan_obj``.

    A reduce hook is installed on ``IterDataPipe`` and ``scan_obj`` is
    dumped into an in-memory buffer. Every object the hook receives that
    does not compare equal to ``scan_obj`` is appended to the result and
    replaced by a stub. The hooks are always reset to ``None`` afterwards.

    Args:
        scan_obj: Object to pickle.
        exclude_primitive: When truthy, also install a getstate hook that
            drops ``__dict__`` entries that are callable or instances of
            ``PRIMITIVE``.

    Returns:
        The list of captured objects, in the order the hook saw them.
    """
    buffer = io.BytesIO()
    # Not going to work for lambdas, but dill loops infinitely on typing
    # and can't be used as is.
    pickler = pickle.Pickler(buffer)

    # NOTE(review): this helper is not referenced anywhere in this function.
    def stub_pickler(obj):
        return stub_unpickler, ()

    captured_connections = []

    def filtered_state(obj):
        # Keep only attributes that are neither callables nor primitives.
        return {
            key: value
            for key, value in obj.__dict__.items()
            if not callable(value) and not isinstance(value, PRIMITIVE)
        }

    def record_connection(obj):
        if obj == scan_obj:
            # presumably tells the caller to fall back to its regular
            # reduce logic for the root object - confirm against the hook's
            # call site
            raise NotImplementedError
        captured_connections.append(obj)
        return stub_unpickler, ()

    try:
        IterDataPipe.set_reduce_ex_hook(record_connection)
        if exclude_primitive:
            IterDataPipe.set_getstate_hook(filtered_state)
        pickler.dump(scan_obj)
    except AttributeError:  # unpicklable DataPipesGraph
        # TODO(VitalyFedyunin): We need to tighten this requirement after
        # migrating from old DataLoader
        pass
    finally:
        IterDataPipe.set_reduce_ex_hook(None)
        if exclude_primitive:
            IterDataPipe.set_getstate_hook(None)

    return captured_connections
def __call__(self, cls):
    """Validate ``cls`` and register it on ``IterDataPipe`` under ``self.name``.

    Args:
        cls: The decorated object. Must be a ``non_deterministic`` instance
            or a class that subclasses ``IterDataPipe``.

    Returns:
        ``cls`` itself, unchanged.

    Raises:
        TypeError: If ``cls`` is neither of the above. ``TypeError`` is a
            subclass of ``Exception``, so handlers written for the
            previously raised bare ``Exception`` still catch it.
    """
    # issubclass() raises its own TypeError for non-class arguments, so it
    # is only called on real classes; everything else gets the message below.
    accepted = isinstance(cls, non_deterministic) or (
        isinstance(cls, type) and issubclass(cls, IterDataPipe)
    )
    if not accepted:
        raise TypeError('Can only decorate IterDataPipe')

    IterDataPipe.register_datapipe_as_function(self.name, cls)
    return cls