def process_frame(self, frame):
    """
    Reconstruct a single hologram and write the result into 'wave'.

    The frame is Fourier transformed, rolled by ``params.sb_position``,
    cropped with the slice from ``task_data``, multiplied with the
    aperture from ``task_data`` and transformed back.

    Parameters
    ----------
    frame
       single frame (hologram) of the data
    """
    xp = self.xp
    if not self.params.precision:
        frame = frame.astype(np.float32)

    # Number of pixels in one frame: the forward FFT is divided by it and
    # the inverse FFT is multiplied by it again.
    n_px = prod(self.meta.partition_shape.sig)
    shift = self.params.sb_position
    aperture = self.task_data.aperture
    crop = self.task_data.slice

    spectrum = xp.fft.fft2(frame) / n_px
    spectrum = xp.roll(spectrum, shift, axis=(0, 1))

    # Shift, crop, then shift again so that the cropped spectrum is
    # processed in the layout the second fftshift produces.
    centered = xp.fft.fftshift(spectrum)
    cropped = xp.fft.fftshift(centered[crop])

    wav = xp.fft.ifft2(cropped * aperture) * n_px
    # FIXME check if result buffer with where='device' and export is faster
    # than exporting frame by frame, as implemented now.
    if self.meta.device_class == 'cuda':
        # That means xp is cupy
        wav = xp.asnumpy(wav)
    self.results.wave[:] = wav
def initialize(self, executor):
    """
    Read the header through the executor and derive shapes and metadata.

    ``nav_shape`` defaults to ``(NY, NX)`` and ``sig_shape`` to
    ``(DP_SZ, DP_SZ)`` from the header if they were not set before.

    Raises
    ------
    DataSetException
        If a user-provided ``sig_shape`` does not contain
        ``DP_SZ * DP_SZ`` items.

    Returns
    -------
    self
    """
    header = executor.run_function(self._read_header)
    self._header = header
    scan_y, scan_x, dp_size = (
        int(header[key]) for key in ('NY', 'NX', 'DP_SZ')
    )
    native_sig_size = dp_size * dp_size

    self._image_count = scan_y * scan_x
    if self._nav_shape is None:
        self._nav_shape = (scan_y, scan_x)
    if self._sig_shape is None:
        self._sig_shape = (dp_size, dp_size)
    elif int(prod(self._sig_shape)) != native_sig_size:
        raise DataSetException(
            "sig_shape must be of size: %s" % native_sig_size
        )

    self._nav_shape_product = int(prod(self._nav_shape))
    self._sync_offset_info = self.get_sync_offset_info()
    self._shape = Shape(
        self._nav_shape + self._sig_shape,
        sig_dims=len(self._sig_shape),
    )
    self._meta = DataSetMeta(
        shape=self._shape,
        raw_dtype=np.dtype("u1"),
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    self._filesize = executor.run_function(self._get_filesize)
    return self
def get_offsets_sizes(self, size: int) -> OffsetsSizes:
    """
    Compute the file-level and frame-level offsets and sizes.

    Parameters
    ----------
    size : int
        len(memoryview) for the whole file

    Returns
    -------
    OffsetsSizes
        ``file_offset`` is taken from ``self._file_header``, ``skip_end``
        is the amount to cut off at the end of the file, ``frame_offset``
        is the frame header expressed in items of the native dtype and
        ``frame_size`` is the number of items in one frame.
    """
    itemsize = np.dtype(self._native_dtype).itemsize
    # Header and footer must be whole multiples of the item size, since
    # the frame offset is expressed in items below.
    assert self._frame_header % itemsize == 0
    assert self._frame_footer % itemsize == 0

    items_per_frame = int(prod(self._sig_shape))
    file_offset = self._file_header

    # cut off any extra data at the end of the file:
    # NOTE(review): `size` is a length in bytes but is taken modulo the
    # number of items per frame - confirm this is the intended check.
    if size % items_per_frame:
        bytes_per_frame = (
            (itemsize * items_per_frame) + self._frame_header + self._frame_footer
        )
        skip_end = (size - file_offset) - self.num_frames * bytes_per_frame
    else:
        skip_end = 0

    return OffsetsSizes(
        file_offset=file_offset,
        skip_end=skip_end,
        frame_offset=self._frame_header // itemsize,
        frame_size=items_per_frame,
    )
def _do_initialize(self):
    """
    Inspect the MRC file at ``self._path`` and set up shapes and metadata.

    The navigation shape defaults to the length of the first axis of the
    memory-mapped data; the signal shape defaults to all entries of
    ``gridSize`` that are not 1.

    Raises
    ------
    DataSetException
        If a user-provided ``sig_shape`` has a different number of items
        than the native signal shape.

    Returns
    -------
    self
    """
    self._filesize = os.stat(self._path).st_size
    mrc = fileMRC(self._path)
    mapped = mrc.getMemmap()
    stack_shape = mapped.shape
    raw_dtype = mapped.dtype

    self._image_count = stack_shape[0]
    if self._nav_shape is None:
        self._nav_shape = (int(stack_shape[0]),)

    # Axes of extent 1 are dropped from the native signal shape.
    native_sig_shape = tuple(
        int(extent) for extent in mrc.gridSize if extent != 1
    )
    if self._sig_shape is None:
        self._sig_shape = native_sig_shape
    else:
        native_sig_size = int(prod(native_sig_shape))
        if int(prod(self._sig_shape)) != native_sig_size:
            raise DataSetException(
                "sig_shape must be of size: %s" % native_sig_size
            )

    self._sig_dims = len(self._sig_shape)
    self._shape = Shape(
        self._nav_shape + self._sig_shape, sig_dims=self._sig_dims
    )
    self._nav_shape_product = self._shape.nav.size
    self._sync_offset_info = self.get_sync_offset_info()
    self._meta = DataSetMeta(
        shape=self._shape,
        raw_dtype=raw_dtype,
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    return self
def get_result_buffers(self):
    ''
    # Declares one 'intensity' buffer that holds every selected frame, and
    # logs a warning if the amount of data exceeds a fixed limit.
    dtype = self.meta.input_dtype
    sig_shape = tuple(self.meta.dataset_shape.sig)

    # Number of frames to load: all frames, or those selected by the ROI.
    roi = self.meta.roi
    if roi is None:
        n_frames = prod(self.meta.dataset_shape.nav)
    else:
        n_frames = np.count_nonzero(roi)

    warn_limit = 2**28
    loaded_size = prod(sig_shape) * n_frames * np.dtype(dtype).itemsize
    if loaded_size > warn_limit:
        log.warning(
            "PickUDF is loading %s bytes, exceeding warning limit %s. "
            "Consider using or implementing an UDF to process data on the worker "
            "nodes instead." % (loaded_size, warn_limit))

    # A "single" buffer is used because mostly single frames are loaded.
    # A "sig" buffer would also work, but would need a transpose to
    # accommodate multiple frames in the last instead of the first
    # dimension. A "nav" buffer would allocate a NaN-filled buffer for the
    # whole dataset.
    intensity = self.buffer(
        kind='single',
        extra_shape=(n_frames, ) + sig_shape,
        dtype=dtype,
    )
    return {'intensity': intensity}
def _get_size(
        self, io_max_size, udf: UDFProtocol, itemsize,
        approx_partition_shape: Shape, base_shape):
    """
    Calculate the maximum tile size in bytes for the given UDF.

    The result depends on ``udf.get_method()``:

    - ``"frame"``: the larger of the default size and one signal frame
      of the approximate partition shape
    - ``"partition"``: the whole approximate partition
    - ``"tile"``: the UDF size preference, capped at ``io_max_size`` and
      raised to at least one ``base_shape`` block
    """
    method = udf.get_method()
    partition_bytes = itemsize * prod(tuple(approx_partition_shape))
    frame_bytes = itemsize * prod(tuple(approx_partition_shape.sig))

    # NOTE(review): any other method name leaves `size` unassigned, so the
    # final return fails - confirm that only these three can occur.
    if method == "tile":
        # start with the UDF size preference, constrained to the maximum
        # read size:
        preferred = min(self._get_udf_size_pref(udf), io_max_size)
        # if the base_shape is larger than that, the size has to grow:
        size = max(itemsize * prod(base_shape), preferred)
    elif method == "partition":
        size = partition_bytes
    elif method == "frame":
        size = max(self._get_default_size(), frame_bytes)
    return size
def test_prod(sequence, ref, typ):
    """
    Check ``prod`` for one case.

    If ``typ`` is ``int``, the result has to equal ``ref`` and be an
    instance of ``typ``; otherwise ``typ`` is the exception type that
    ``prod(sequence)`` is expected to raise.
    """
    if typ is not int:
        with pytest.raises(typ):
            prod(sequence)
        return

    result = prod(sequence)
    assert result == ref
    assert isinstance(result, typ)
def new_for_partition(self, partition, roi):
    """
    Create a new AuxBufferWrapper restricted to one partition.

    The data is sliced with the navigation slice of the partition and,
    if a roi is given, reduced to the selected entries.

    This must be called on an AuxBufferWrapper that was not itself
    created by this method: its data has to be in global coordinates,
    without a ROI applied.
    """
    # FIXME: right now creates a view for the partition slice, which
    # AFAIK means we serialize the whole array; we could optimize here
    # and only send over the partition slice. But maybe, later, when we
    # actually properly scatter and share data, this becomes obsolete anyways,
    # as we would scatter most likely for all partitions (to be flexible in node
    # assignment, for example for availability)
    assert self._data_coords_global
    nav_slice = partition.slice.get(nav_only=True)
    wrapper = self.__class__(self._kind, self._extra_shape, self._dtype)

    part_data = self._data[nav_slice]
    if roi is not None:
        # Flatten the roi and cut out the part that belongs to this
        # partition, then use it as a selection on the partition data.
        part_data = part_data[roi.reshape(-1)[nav_slice]]

    wrapper.set_buffer(part_data, is_global=False)
    wrapper.set_roi(roi)
    assert prod(part_data.shape) > 0
    assert not wrapper._data_coords_global
    return wrapper
def zeros(self, size, dtype, alignment=4096):
    """
    Yield a zero-filled array of shape `size` and the given dtype.

    Object dtype and empty shapes are served by a plain ``np.zeros``;
    everything else is obtained from ``self.empty`` and set to zero.
    """
    if not (dtype == object or prod(size) == 0):
        with self.empty(size, dtype, alignment) as buf:
            buf[:] = 0
            yield buf
        return
    yield np.zeros(size, dtype=dtype)
def zeros_aligned(size: BufferSize, dtype: "nt.DTypeLike") -> np.ndarray:
    """
    Return a zero-filled array of shape `size` and the given dtype.

    Object dtype and empty shapes are served by a plain ``np.zeros``;
    everything else is allocated with ``empty_aligned`` and set to zero.
    """
    if not (dtype == object or prod(size) == 0):
        buf = empty_aligned(size, dtype)
        buf[:] = 0
        return buf
    return np.zeros(size, dtype=dtype)
def empty_aligned(size: BufferSize, dtype: "nt.DTypeLike") -> np.ndarray:
    """
    Return an uninitialized array of shape `size` backed by a buffer
    obtained from ``_alloc_aligned``.
    """
    n_items = prod(size)
    np_dtype = np.dtype(dtype)
    raw = _alloc_aligned(np_dtype.itemsize * n_items)
    # _alloc_aligned may give us more memory (for alignment reasons), so
    # only the first n_items items are kept:
    flat: np.ndarray = np.frombuffer(raw, dtype=np_dtype)[:n_items]
    return flat.reshape(size)
def correct_dot_masks(masks, gain_map, excluded_pixels=None, allow_empty=False):
    """
    Fold a gain map, and optionally a pixel repair, into a mask stack.

    The masks are flattened to ``(-1, prod(gain_map.shape))``. If
    `excluded_pixels` is given, each excluded column is zeroed and its
    original values, divided by the repair count, are added to the
    columns listed as its repair pixels. Finally the masks are multiplied
    with the flattened gain map and reshaped to their original shape.
    """
    original_shape = masks.shape
    sig_shape = gain_map.shape
    flat_masks = masks.reshape((-1, prod(sig_shape)))

    if excluded_pixels is None:
        patched = flat_masks
    else:
        if is_sparse(flat_masks):
            patched = sparse.DOK(flat_masks)
        else:
            patched = flat_masks.copy()
        descriptor = RepairDescriptor(
            sig_shape,
            excluded_pixels=excluded_pixels,
            allow_empty=allow_empty,
        )
        repair_plan = zip(
            descriptor.exclude_flat,
            descriptor.repair_flat,
            descriptor.repair_counts,
        )
        for excluded, targets, count in repair_plan:
            patched[:, excluded] = 0
            # Contribution each repair pixel receives, taken from the
            # unmodified masks.
            share = flat_masks[:, excluded] / count
            # We have to loop because of sparse.pydata limitations
            for row in range(patched.shape[0]):
                for target in targets[:count]:
                    patched[row, target] = patched[row, target] + share[row]
        if is_sparse(patched):
            patched = sparse.COO(patched)

    patched = patched * gain_map.flatten()
    return patched.reshape(original_shape)
def empty(self, size, dtype, alignment=4096):
    """
    Yield an uninitialized array of shape `size` and the given dtype,
    backed by a buffer obtained from ``self.bytes``.
    """
    n_items = prod(size)
    np_dtype = np.dtype(dtype)
    n_bytes = np_dtype.itemsize * n_items
    with self.bytes(n_bytes, alignment) as raw:
        # self.bytes may give us more memory (for alignment reasons), so
        # only the first n_items items are kept:
        flat = np.frombuffer(raw, dtype=np_dtype)[:n_items]
        yield flat.reshape(size)
def _get_io_max_size(self, dataset, approx_partition_shape, itemsize, need_decode):
    """
    Return the maximum I/O size in bytes.

    Without decoding this is the byte size of the approximate partition
    shape. With decoding, the dataset is asked via ``get_max_io_size()``;
    if it returns ``None``, ``2**20`` is used.
    """
    if not need_decode:
        return itemsize * prod(approx_partition_shape)
    limit = dataset.get_max_io_size()
    return 2**20 if limit is None else limit
def initialize(self, executor):
    """
    Collect per-file offsets and z-sizes and set up shapes and metadata.

    If ``self._same_offset`` is set, the metadata of the first file is
    read once and applied to all files; otherwise the metadata of every
    file is read through ``executor.map``.

    Raises
    ------
    DataSetException
        If a user-provided ``sig_shape`` has a different number of items
        than the native signal shape.

    Returns
    -------
    self
    """
    self._filesize = executor.run_function(self._get_filesize)

    if self._same_offset:
        shared = executor.run_function(_get_metadata, self._get_files()[0])
        self._offsets = {
            fn: shared['offset'] for fn in self._get_files()
        }
        self._z_sizes = {
            fn: shared['zsize'] for fn in self._get_files()
        }
    else:
        per_file = dict(zip(
            self._get_files(),
            executor.map(_get_metadata, self._get_files()),
        ))
        self._offsets = {
            fn: per_file[fn]['offset'] for fn in self._get_files()
        }
        self._z_sizes = {
            fn: per_file[fn]['zsize'] for fn in self._get_files()
        }

    total_frames = sum(self._z_sizes.values())
    self._image_count = total_frames
    if self._nav_shape is None:
        self._nav_shape = (total_frames, )

    native_sig_shape, native_dtype = executor.run_function(
        self._get_sig_shape_and_native_dtype
    )
    if self._sig_shape is None:
        self._sig_shape = tuple(native_sig_shape)
    else:
        native_sig_size = int(prod(native_sig_shape))
        if int(prod(self._sig_shape)) != native_sig_size:
            raise DataSetException(
                "sig_shape must be of size: %s" % native_sig_size
            )

    full_shape = self._nav_shape + self._sig_shape
    self._nav_shape_product = int(prod(self._nav_shape))
    self._sync_offset_info = self.get_sync_offset_info()
    self._meta = DataSetMeta(
        shape=Shape(full_shape, sig_dims=len(self._sig_shape)),
        raw_dtype=native_dtype,
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    self._fileset = executor.run_function(self._get_fileset)
    return self
def _do_initialize(self):
    """
    Read the file headers and set up shapes, dtypes and metadata.

    The native frame size comes from the first file header, unfolded and
    scaled by the bin factor of the readout mode. The navigation shape
    defaults to ``stemimagesize`` from the hdr info.

    Raises
    ------
    DataSetException
        If a user-provided ``sig_shape`` has a different number of items
        than the native frame size.

    Returns
    -------
    self
    """
    self._filenames = get_filenames(self._path)
    self._hdr_info = self._read_hdr_info()
    self._headers = [_read_file_header(path) for path in self._files()]

    first_header = self._headers[0]
    height, width = first_header['height'], first_header['width']
    raw_frame_size = height, width
    assert width % 2 == 0

    # frms6 frames are folded in a specific way, this is the shape after
    # unfolding:
    unfolded_height = 2 * height
    unfolded_width = width // 2

    hdr = self._get_hdr_info()
    bin_factor = hdr['readoutmode']['bin']
    if bin_factor > 1:
        unfolded_height = unfolded_height * bin_factor
    frame_size = (unfolded_height, unfolded_width)

    self._image_count = int(hdr['signalframes'])
    if self._nav_shape is None:
        self._nav_shape = tuple(hdr['stemimagesize'])
    if self._sig_shape is None:
        self._sig_shape = frame_size
    else:
        native_sig_size = int(prod(frame_size))
        if int(prod(self._sig_shape)) != native_sig_size:
            raise DataSetException(
                "sig_shape must be of size: %s" % native_sig_size
            )

    self._nav_shape_product = int(prod(self._nav_shape))
    self._sync_offset_info = self.get_sync_offset_info()

    # With offset correction enabled, float32 is preferred over the raw
    # dtype.
    if self._enable_offset_correction:
        preferred_dtype = np.dtype("float32")
    else:
        preferred_dtype = np.dtype("<u2")

    self._meta = DataSetMeta(
        raw_dtype=np.dtype("<u2"),
        dtype=preferred_dtype,
        metadata={'raw_frame_size': raw_frame_size},
        shape=Shape(
            self._nav_shape + self._sig_shape,
            sig_dims=len(self._sig_shape),
        ),
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    self._dark_frame = self._get_dark_frame()
    self._gain_map = self._get_gain_map()
    self._total_filesize = sum(
        os.stat(path).st_size for path in self._files()
    )
    return self
def _do_initialize(self):
    """
    Read the file header and set up frame layout, shapes and metadata.

    The footer size is the difference between ``true_image_size`` from
    the header and the plain frame size in bytes; the image count is
    derived from the file size, the image offset and ``true_image_size``.

    Raises
    ------
    DataSetException
        If the bit depth cannot be mapped to an unsigned integer dtype,
        or if a user-provided ``sig_shape`` does not contain
        ``height * width`` items.

    Returns
    -------
    self
    """
    header = self._header = _read_header(self._path, HEADER_FIELDS)
    self._image_offset = _get_image_offset(header)

    # From version 5 on (StreamPix version 6):
    # Timestamp = 4-byte unsigned long + 2-byte unsigned short (ms)
    #   + 2-byte unsigned short (us)
    # Older versions don't have this format.
    self._timestamp_micro = bool(header['version'] >= 5)

    try:
        dtype = np.dtype('uint%i' % header['bit_depth'])
    except TypeError:
        raise DataSetException("unsupported bit depth: %s" % header['bit_depth'])

    items_per_frame = header['height'] * header['width']
    self._footer_size = header['true_image_size'] - items_per_frame * dtype.itemsize
    self._filesize = os.stat(self._path).st_size
    payload_bytes = self._filesize - self._image_offset
    self._image_count = int(payload_bytes / header['true_image_size'])

    if self._sig_shape is None:
        self._sig_shape = (header['height'], header['width'])
    elif int(prod(self._sig_shape)) != items_per_frame:
        raise DataSetException(
            "sig_shape must be of size: %s" % items_per_frame
        )

    self._nav_shape_product = int(prod(self._nav_shape))
    self._sync_offset_info = self.get_sync_offset_info()
    self._meta = DataSetMeta(
        shape=Shape(
            self._nav_shape + self._sig_shape,
            sig_dims=len(self._sig_shape),
        ),
        raw_dtype=dtype,
        dtype=dtype,
        metadata=header,
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    self._maybe_load_dark_gain()
    return self
def __init__(self, path, structure=None, io_backend=None):
    """
    Set up the dataset from a path and a structure description.

    Shape and dtype are taken from `structure`; the image count is the
    product of the navigation shape.
    """
    super().__init__(io_backend=io_backend)
    self._path = path
    self._structure = structure
    # NOTE(review): `structure` defaults to None, but its attributes are
    # read unconditionally below - confirm callers always pass it.
    self._dtype = structure.dtype
    frame_count = int(prod(structure.shape.nav))
    self._meta = DataSetMeta(
        shape=structure.shape,
        raw_dtype=np.dtype(structure.dtype),
        sync_offset=0,
        image_count=frame_count,
    )
    self._executor = None
def _do_initialize(self):
    """
    Find the files, validate the signal shape and set up the metadata.

    The native signal shape is
    ``(SECTOR_SIZE[0], NUM_SECTORS * SECTOR_SIZE[1])``.

    Raises
    ------
    DataSetException
        If a user-provided ``sig_shape`` has a different number of items
        than the native signal shape.

    Returns
    -------
    self
    """
    self._files = self._get_files()
    self._set_skip_frames_and_nav_shape()

    native_sig_shape = (SECTOR_SIZE[0], NUM_SECTORS * SECTOR_SIZE[1])
    if self._sig_shape is None:
        self._sig_shape = native_sig_shape
    else:
        native_sig_size = int(prod(native_sig_shape))
        if int(prod(self._sig_shape)) != native_sig_size:
            raise DataSetException(
                "sig_shape must be of size: %s" % native_sig_size
            )

    # The frame count is taken from a syncer created without syncing;
    # after the sync offset is set, a syncer is requested with syncing.
    self._image_count = _get_num_frames(self._get_syncer(do_sync=False))
    self._set_sync_offset()
    self._get_syncer(do_sync=True)

    self._meta = DataSetMeta(
        shape=Shape(
            self._nav_shape + self._sig_shape,
            sig_dims=len(self._sig_shape),
        ),
        raw_dtype=np.dtype("uint16"),
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    return self
def get_array_from_memview(self, mem: memoryview, slicing: OffsetsSizes):
    """
    Turn the memoryview of a file into a 2D array with one row per frame.

    The first ``slicing.file_offset`` bytes and the last
    ``slicing.skip_end`` bytes are dropped, the remainder is truncated to
    the bytes of ``self.num_frames`` frames, reinterpreted as the native
    dtype and reshaped to ``(num_frames, -1)``.

    Parameters
    ----------
    mem : memoryview
        The data of the whole file
    slicing : OffsetsSizes
        Offsets and sizes; ``file_offset`` and ``skip_end`` are applied
        in bytes, ``frame_offset`` and ``frame_size`` in items of the
        native dtype.

    Returns
    -------
    np.ndarray
        For every frame, the ``frame_size`` items that start at
        ``frame_offset``.
    """
    # A stop index of `-0` is just `0` and would select nothing at all, so
    # a negative stop is only used if there is something to cut off.
    stop = -slicing.skip_end if slicing.skip_end else None
    mem = mem[slicing.file_offset:stop]
    res = np.frombuffer(mem, dtype="uint8")
    itemsize = np.dtype(self._native_dtype).itemsize
    sigsize = int(prod(self._sig_shape))
    # Only keep as many bytes as the expected number of frames occupy:
    cutoff = self.num_frames * itemsize * sigsize
    res = res[:cutoff]
    frames = res.view(dtype=self._native_dtype).reshape(
        (self.num_frames, -1)
    )
    return frames[:, slicing.frame_offset:slicing.frame_offset + slicing.frame_size]
def get_num_partitions(self) -> int:
    """
    Returns the number of partitions the dataset should be split into.

    Partitions are sized such that, at four bytes per pixel, they fit
    into :code:`self.MAX_PARTITION_SIZE` bytes, regardless of the native
    dtype. At least :code:`self._cores` partitions are created.
    """
    # Four bytes per pixel, i.e. the size of a 32 bit float.
    px_per_partition = self.MAX_PARTITION_SIZE // 4
    by_size = prod(self.shape) // px_per_partition
    num: int = max(self._cores, by_size)
    return num
def validate(
    self,
    shape: Tuple[int, ...],
    ds_sig_shape: Tuple[int, ...],
    size: int,
    io_max_size: int,
    itemsize: int,
    base_shape: Tuple[int, ...],
    corrections: Optional[CorrectionSet],
):
    """
    Check a generated tile shape against the negotiated constraints.

    Raises
    ------
    ValueError
        If the signal part of `shape` exceeds `ds_sig_shape` in any
        dimension, if `shape` contains more items than fit into the
        larger of `size` and `io_max_size`, or if `shape` is not a
        multiple of `base_shape` in some dimension.

    Warns
    -----
    If the size is exceeded while excluded pixels are set on
    `corrections` and the tile depth is 1, a warning is issued instead
    of raising.
    """
    sig_shape = shape[1:]
    for tile_extent, ds_extent in zip(sig_shape, ds_sig_shape):
        if tile_extent > ds_extent:
            raise ValueError("generated tileshape does not fit the partition")

    # we need some wiggle room with the size, because there may be a harder
    # lower size value for some cases (for example HDF5, which overrides
    # some of the sizing negotiation we are doing here)
    size_px = max(size, io_max_size) // itemsize
    shape_px = prod(shape)
    if shape_px > size_px:
        message = "shape %r (%d) does not fit into size %d" % (
            shape, shape_px, size_px
        )
        # The shape might be exceeded if dead pixel correction didn't find a
        # valid tiling scheme. In that case it falls back to by-frame processing.
        by_frame_fallback = (
            corrections is not None
            and corrections.get_excluded_pixels() is not None
            and shape[0] == 1
        )
        if not by_frame_fallback:
            raise ValueError(message)
        warnings.warn(message)

    for dim, base_extent in enumerate(base_shape):
        if shape[dim] % base_extent != 0:
            raise ValueError(
                f"The tileshape {shape} is incompatible with base "
                f"shape {base_shape} in dimension {dim}."
            )
def _do_initialize(self):
    """
    Read the file headers and set up shapes and metadata.

    Files are ordered by the ``sequence_first_image`` header field; the
    first file provides the native image size and the dtype. If no
    navigation shape is set, it is derived from the hdr file at
    ``self._path``.

    Raises
    ------
    DataSetException
        If no files are found, or if a user-provided ``sig_shape`` has a
        different number of items than the native image size.

    Returns
    -------
    self
    """
    def _first_image(f):
        return f.fields['sequence_first_image']

    self._headers = self._preread_headers()
    self._files_sorted = sorted(self._files(), key=_first_image)

    try:
        first_file = self._files_sorted[0]
    except IndexError:
        raise DataSetException("no files found")

    if self._nav_shape is None:
        self._nav_shape = nav_shape_from_hdr(read_hdr_file(self._path))

    native_sig_shape = first_file.fields['image_size']
    if self._sig_shape is None:
        self._sig_shape = native_sig_shape
    else:
        native_sig_size = int(prod(native_sig_shape))
        if int(prod(self._sig_shape)) != native_sig_size:
            raise DataSetException(
                "sig_shape must be of size: %s" % native_sig_size
            )

    self._sig_dims = len(self._sig_shape)
    full_shape = Shape(
        self._nav_shape + self._sig_shape, sig_dims=self._sig_dims
    )
    raw_dtype = first_file.fields['dtype']
    self._total_filesize = sum(
        os.stat(path).st_size for path in self._filenames()
    )
    self._sequence_start = first_file.fields['sequence_first_image']
    # The file list is sorted a second time here, as before.
    self._files_sorted = sorted(self._files(), key=_first_image)
    self._image_count = self._num_images()
    self._nav_shape_product = int(prod(self._nav_shape))
    self._sync_offset_info = self.get_sync_offset_info()
    self._meta = DataSetMeta(
        shape=full_shape,
        raw_dtype=raw_dtype,
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    return self
def initialize(self, executor):
    """
    Determine the file size and derive image count and metadata from it.

    The image count is the file size divided by the size of one frame in
    bytes, truncated to an integer.

    Raises
    ------
    DataSetException
        If ``sig_shape`` contains more items than the whole file.

    Returns
    -------
    self
    """
    self._filesize = executor.run_function(self._get_filesize)
    itemsize = np.dtype(self._dtype).itemsize

    # Number of items of the given dtype that fit into the file:
    items_in_file = int(self._filesize / itemsize)
    if int(prod(self._sig_shape)) > items_in_file:
        raise DataSetException(
            "sig_shape must be less than size: %s" % (items_in_file)
        )

    frame_bytes = itemsize * prod(self._sig_shape)
    self._image_count = int(self._filesize / frame_bytes)
    self._nav_shape_product = int(prod(self._nav_shape))
    self._sync_offset_info = self.get_sync_offset_info()
    self._meta = DataSetMeta(
        shape=Shape(
            self._nav_shape + self._sig_shape, sig_dims=self._sig_dims
        ),
        raw_dtype=np.dtype(self._dtype),
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    return self
def _get_scale_factors(self, shape, containing_shape, size, min_factors=None):
    """
    Generate scaling factors to scale `shape` up to `size` elements,
    while being constrained to `containing_shape`.

    Factors start at `min_factors` (or 1 for every dimension) and are
    grown one dimension after the other; a factor never drops below its
    start value and is capped at ``containing_shape // shape`` for its
    dimension.
    """
    log.debug(
        "_get_scale_factors in: shape=%r, containing_shape=%r, size=%r, min_factors=%r",
        shape, containing_shape, size, min_factors
    )
    assert len(shape) == len(containing_shape)

    factors = [1] * len(shape) if min_factors is None else list(min_factors)
    upper_bounds = tuple(
        cs // s for s, cs in zip(shape, containing_shape)
    )

    # How much the currently scaled shape may still grow to reach `size`:
    scaled = self._scale_base_shape(shape, factors)
    rest = size / prod(scaled)
    if rest < 1:
        rest = 1

    for idx in range(len(shape)):
        lower = factors[idx]
        grown = int(math.floor(rest * lower))
        # keep at least the current factor, but don't exceed the bound:
        factors[idx] = min(max(grown, lower), upper_bounds[idx])
        # update the remaining growth for the next dimension:
        scaled = self._scale_base_shape(shape, factors)
        rest = max(1, math.floor(size / prod(scaled)))

    log.debug(
        "_get_scale_factors out: %r",
        factors,
    )
    return factors
def _do_initialize(self):
    """
    Inspect the SER file at ``self._path`` and set up shapes and metadata.

    The number of frames, the dtype, the navigation dimensions and the
    shape of the first dataset are read from the file. The navigation
    dimensions are the header dimension sizes in reversed order.

    Raises
    ------
    DataSetException
        If the file contains no valid elements, or if a user-provided
        ``sig_shape`` has a different number of items than the first
        dataset.

    Returns
    -------
    self
    """
    self._filesize = os.stat(self._path).st_size
    reader = SERFile(path=self._path, num_frames=None)

    with reader.get_handle() as handle:
        head = handle.head
        self._num_frames = head['ValidNumberElements']
        if head['ValidNumberElements'] == 0:
            raise DataSetException("no data found in file")
        first_frame, frame_meta = handle.getDataset(0)
        raw_dtype = handle._dictDataType[frame_meta['DataType']]
        dim_sizes = [
            int(dim['DimensionSize']) for dim in handle.head['Dimensions']
        ]
        nav_dims = tuple(reversed(dim_sizes))

    self._image_count = int(self._num_frames)
    if self._nav_shape is None:
        self._nav_shape = nav_dims
    if self._sig_shape is None:
        self._sig_shape = tuple(first_frame.shape)
    else:
        native_sig_size = int(prod(first_frame.shape))
        if int(prod(self._sig_shape)) != native_sig_size:
            raise DataSetException(
                "sig_shape must be of size: %s" % native_sig_size
            )

    self._nav_shape_product = int(prod(self._nav_shape))
    self._sync_offset_info = self.get_sync_offset_info()
    self._shape = Shape(
        self._nav_shape + self._sig_shape,
        sig_dims=len(self._sig_shape),
    )
    self._meta = DataSetMeta(
        shape=self._shape,
        raw_dtype=raw_dtype,
        sync_offset=self._sync_offset,
        image_count=self._image_count,
    )
    return self
def _get_tiles_by_block(
    self,
    tiling_scheme,
    open_files,
    read_ranges,
    read_dtype,
    native_dtype,
    decoder=None,
    corrections=None,
    sync_offset=0,
):
    """
    Yield tiles, reading the data in blocks of ``len(tiling_scheme)``
    entries of `read_ranges`.

    A single decode buffer, large enough for the largest slice of the
    tiling scheme at full depth, is taken from the buffer pool and shared
    by all blocks. Each block is delegated to ``_read_block_dense``.

    If no `decoder` is given, a ``DtypeConversionDecoder`` is used.
    `sync_offset` is accepted but not used in this method.
    """
    if decoder is None:
        decoder = DtypeConversionDecoder()
    decode_fn = decoder.get_decode(
        native_dtype=np.dtype(native_dtype),
        read_dtype=np.dtype(read_dtype),
    )
    read_and_decode = self.get_read_and_decode(decode_fn)
    self._r_n_d = read_and_decode

    native_dtype = decoder.get_native_dtype(native_dtype, read_dtype)
    sig_dims = tiling_scheme.shape.sig.dims
    ds_shape = np.array(tiling_scheme.dataset_shape)

    # Pick the slice with the most elements; it determines the buffer size.
    slices_by_size = sorted(
        ((prod(s_.shape), s_) for _, s_ in tiling_scheme.slices),
        key=lambda pair: pair[0],
        reverse=True,
    )
    largest_slice = slices_by_size[0][1]
    buf_shape = (tiling_scheme.depth,) + tuple(largest_slice.shape)
    need_clear = decoder.do_clear()

    slices = read_ranges[0]
    # Use NumPy prod for multidimensional array and axis parameter
    shape_prods = np.prod(slices[..., 1, :], axis=1, dtype=np.int64)
    ranges = read_ranges[1]
    scheme_indices = read_ranges[2]

    block_len = len(tiling_scheme)

    with self._buffer_pool.empty(buf_shape, dtype=read_dtype) as out_decoded:
        out_decoded = out_decoded.reshape((-1,))
        for block_start in range(0, slices.shape[0], block_len):
            block_ranges = ranges[block_start:block_start + block_len]
            # Only the per-file minima and maxima are needed below.
            _fill_factor, _req_buf_size, min_per_file, max_per_file = \
                block_get_min_fill_factor(block_ranges)
            # TODO: if it makes sense, implement sparse variant
            # if req_buf_size > self._max_buffer_size or fill_factor < self._sparse_threshold:
            yield from self._read_block_dense(
                block_start, block_len, min_per_file, max_per_file, open_files,
                slices, ranges, scheme_indices, shape_prods, out_decoded,
                read_and_decode, sig_dims, ds_shape, need_clear, native_dtype,
                corrections,
            )
def detect_params(cls, path, executor):
    """
    Detect dataset parameters for a file whose name ends with ".ser".

    Returns
    -------
    dict or False
        ``False`` if `path` does not end with ".ser" (case-insensitive),
        otherwise a dict with the detected ``parameters`` and ``info``,
        taken from an initialized dataset.
    """
    if not path.lower().endswith(".ser"):
        return False

    ds = cls(path).initialize(executor)
    parameters = {
        "path": path,
        "nav_shape": tuple(ds.shape.nav),
        "sig_shape": tuple(ds.shape.sig),
    }
    info = {
        "image_count": int(prod(ds.shape.nav)),
        "native_sig_shape": tuple(ds.shape.sig),
    }
    return {"parameters": parameters, "info": info}
def adjust_tileshape(self, tileshape, roi):
    """
    Adapt the negotiated `tileshape` to the chunking of the dataset.

    - With a roi, single full frames are used.
    - Without chunking information, or if ``_have_contig_chunks`` holds
      for the chunks, `tileshape` is returned unchanged.
    - Otherwise, if the requested signal extent exceeds the chunk in any
      signal dimension, full frames are used with a depth that roughly
      keeps the total tile size; else the tile shape is derived from the
      chunking.
    """
    sig_shape = self.shape.sig
    if roi is not None:
        return (1, ) + sig_shape

    chunks = self._chunks
    if chunks is None or _have_contig_chunks(chunks, self.shape):
        return tileshape

    n_sig = sig_shape.dims
    requested_sig = tileshape[-n_sig:]
    chunk_sig = chunks[-n_sig:]
    exceeds_chunk = any(t > c for t, c in zip(requested_sig, chunk_sig))
    if not exceeds_chunk:
        # depth needs to be limited to prod(chunks.nav)
        return _tileshape_for_chunking(chunks, self.shape)

    # larger signal chunking was requested in the negotiation, so switch
    # to full frames and try to keep the total tileshape size:
    depth = max(1, prod(tileshape) // sig_shape.size)
    return (depth, ) + sig_shape
def get_partition_shape(dataset_shape: Shape, target_size_items: int,
                        min_num: Optional[int] = None) -> Tuple[int, ...]:
    """
    Calculate partition shape for the given ``target_size_items``

    Navigation dimensions are added from the last to the first for as
    long as the partition stays within the target size; the first
    dimension that does not fit completely is cut down, and all remaining
    leading dimensions get an extent of 1.

    Parameters
    ----------
    dataset_shape
        "native" dataset shape
    target_size_items
        target partition size in number of items/pixels
    min_num
        minimum number of partitions

    Returns
    -------
    Tuple[int, ...]
        The partition shape in the navigation dimensions
    """
    sig_size = dataset_shape.sig.size
    divisor = 1 if min_num is None else min_num
    # Don't exceed the size that still allows for `divisor` partitions:
    budget = min(target_size_items, int(dataset_shape.size // divisor))

    partition_nav: Tuple[int, ...] = ()
    for extent in reversed(tuple(dataset_shape.nav)):
        candidate = (extent, ) + partition_nav
        candidate_items = prod(candidate) * sig_size
        if candidate_items > budget:
            # This dimension doesn't fit completely: shrink it by the
            # overshoot, but keep at least one entry, and stop.
            overshoot = candidate_items / budget
            partition_nav = (max(1, int(extent // overshoot)), ) + partition_nav
            break
        partition_nav = candidate

    padding = (1, ) * (len(dataset_shape.nav) - len(partition_nav))
    return padding + partition_nav