コード例 #1
0
ファイル: test_normalize.py プロジェクト: matthew-frank/DALI
    def __init__(self, device, batch_size, dims, axes, axis_names, batch=False,
                 out_type=None, in_type=None, shift=None, scale=None,
                 num_threads=3, device_id=0, num_gpus=1):
        """Set up operators that exercise ``Normalize`` in several configurations.

        Args:
            device: Device string forwarded to every ``Normalize`` operator.
            batch_size, num_threads, device_id: Forwarded to the base pipeline
                (fixed seed 7865, ``exec_async=False``, ``exec_pipelined=False``).
            dims: Rank of the input samples. When ``axis_names`` is given, a
                layout made of the first ``dims`` lowercase letters is attached
                through ``ops.Reshape``.
            axes: Reduction axes as indices, or ``None``.
            axis_names: Reduction axes as layout letters, or ``None``. When
                given, they are converted to indices ('a' -> 0, ...) which replace
                ``axes`` for ``custom_mean``/``custom_stddev`` and ``self.axes``.
            batch: Forwarded to ``Normalize`` and to ``custom_mean``/``custom_stddev``.
            out_type: Output type; converted with ``dali_type`` for ``Normalize``.
            in_type: Stored as ``self.in_type``.
            shift, scale: Forwarded to ``Normalize``.
            num_gpus: Accepted but not used here.
        """
        super(NormalizePipeline, self).__init__(
            batch_size, num_threads, device_id, seed=7865,
            exec_async=False, exec_pipelined=False)
        # Built from the caller's ``axes``/``axis_names`` *before* any axis-name
        # conversion below, so every Normalize receives them exactly as given.
        common_args = dict(
            device=device,
            axes=axes,
            axis_names=axis_names,
            batch=batch,
            dtype=dali_type(out_type),
            shift=shift,
            scale=scale,
        )
        self.in_type, self.out_type, self.device = in_type, out_type, device
        self.input = ops.ExternalSource()
        if axis_names is None:
            self.add_layout = None
        else:
            # Layout 'a', 'ab', 'abc', ... so that axis names can be resolved.
            layout = ''.join(chr(ord('a') + k) for k in range(dims))
            self.add_layout = ops.Reshape(layout=layout)
        self.batch, self.dims = batch, dims
        self.has_axes = not (axes is None and axis_names is None)
        self.scale, self.shift = scale, shift
        self.is_integral = out_type is not None and out_type is not np.float32

        if axis_names is not None:
            # Layout letters map to axis indices: 'a' -> 0, 'b' -> 1, ...
            axes = [ord(name) - ord('a') for name in axis_names]

        self.axes = axes
        self.axis_names = axis_names
        self.ddof = 0 if axes is None or len(axes) == 0 else 2
        self.eps = 0.25
        ddof, eps = self.ddof, self.eps

        self.mean = ops.PythonFunction(function=custom_mean(batch, axes), batch_processing=True)
        self.stddev = ops.PythonFunction(function=custom_stddev(batch, axes), batch_processing=True)
        self.normalize = ops.Normalize(**common_args, ddof=ddof)
        self.scalar_mean = ops.Normalize(**common_args, mean=1, ddof=ddof, epsilon=eps)
        self.scalar_stddev = ops.Normalize(**common_args, stddev=2, epsilon=eps)
        self.scalar_params = ops.Normalize(**common_args, mean=1, stddev=2)
コード例 #2
0
ファイル: debug_without_sacred.py プロジェクト: gddcx/yolov1
 def __init__(self, params, device_id, files, labels):
     """Construct the operators of the augmented training pipeline.

     Args:
         params: Config object; this constructor reads ``batch_size``,
             ``num_gpus`` (8 worker threads are used per GPU), ``seed``,
             ``mean`` and ``std``. ``mean``/``std`` are multiplied by 255 here,
             so they are presumably per-channel values on a 0-1 scale.
         device_id: Device id forwarded to the base pipeline.
         files, labels: Explicit file paths and labels for the reader, which
             reads them in random order.
     """
     num_threads = params.num_gpus * 8
     super().__init__(params.batch_size, num_threads, device_id, seed=params.seed)

     def unit_uniform():
         # Uniform random generator over [0, 1].
         return ops.random.Uniform(range=(0.0, 1.0))

     def colour_jitter_uniform():
         # Uniform random generator over [0.8, 1.0].
         return ops.random.Uniform(range=(0.8, 1.0))

     def to_pixel_scale(values):
         # Leading size-1 dims give shape (1, 1, C) for the reduced axes [0, 1];
         # values are scaled by 255.
         return torch.Tensor(values)[None, None] * 255

     # Pitfall with ``file_root``: the folder name is NOT the label. Folders are
     # labelled 0, 1, 2, 3, 4, ... in folder order, e.g. (1, 10, 11, 2, 20, 21, ...),
     # which is why explicit ``files``/``labels`` are passed instead.
     self.input = ops.FileReader(files=files, labels=labels, random_shuffle=True)
     # (sic) attribute name kept as-is; other code may refer to ``self.decocer``.
     self.decocer = ops.ImageDecoder(device='mixed', output_type=types.RGB)
     self.resize = ops.Resize(device='gpu', resize_shorter=224)
     self.pos_rng_x = unit_uniform()
     self.pos_rng_y = unit_uniform()
     self.crop = ops.Crop(device='gpu', crop_h=224, crop_w=224)
     self.flip = ops.Flip(device='gpu')
     self.coinflip = ops.random.CoinFlip(probability=0.5)
     self.hsv = ops.Hsv(device='gpu')
     self.saturation = colour_jitter_uniform()
     self.value = colour_jitter_uniform()
     mean = to_pixel_scale(params.mean)
     std = to_pixel_scale(params.std)
     self.normalize = ops.Normalize(axes=[0, 1], mean=mean, stddev=std,
                                    device='gpu', batch=False)
     self.transpose = ops.Transpose(device='gpu', perm=[2, 0, 1])
コード例 #3
0
ファイル: data_loading.py プロジェクト: redsphinx/3tconv
    def __init__(self,
                 batch_size,
                 num_threads=6,
                 device_id=0,
                 file_root='',
                 shuffle=False,
                 sequence_length=30,
                 step=-1,
                 stride=1,
                 initial_fill=1024,
                 seed=0,
                 channels=3):
        """Create a single-shard GPU video-reading pipeline.

        Args:
            batch_size, num_threads, device_id, seed: Forwarded to the base pipeline.
            file_root: Directory passed to ``VideoReader(file_root=...)``.
            shuffle: Forwarded to the reader as ``random_shuffle``.
            sequence_length, step, stride, initial_fill, channels: Forwarded
                unchanged to ``VideoReader``.

        The reader always uses ``shard_id=0`` and ``num_shards=1`` (no
        sharding). A GPU ``Normalize`` with default arguments is also created.
        """
        super(VideoPipeFileRoot, self).__init__(batch_size, num_threads, device_id, seed=seed)

        reader_options = {
            'device': 'gpu',
            'file_root': file_root,
            'sequence_length': sequence_length,
            'step': step,
            'stride': stride,
            'shard_id': 0,
            'num_shards': 1,
            'random_shuffle': shuffle,
            'initial_fill': initial_fill,
            'channels': channels,
        }
        self.input = ops.VideoReader(**reader_options)

        self.normalize = ops.Normalize(device='gpu')
コード例 #4
0
ファイル: test_torch_pipeline_rnnt.py プロジェクト: xvdp/DALI
    def __init__(self,
                 device_id,
                 n_devices,
                 file_root,
                 file_list,
                 batch_size,
                 sample_rate=16000,
                 window_size=.02,
                 window_stride=.01,
                 nfeatures=64,
                 nfft=512,
                 frame_splicing_factor=3,
                 silence_threshold=-80,
                 dither=.00001,
                 preemph_coeff=.97,
                 lowfreq=0.0,
                 highfreq=0.0,
                 num_threads=1):
        """Construct the CPU operators of the audio test pipeline.

        Args:
            device_id: Forwarded to the base pipeline and used as the reader's
                ``shard_id``.
            n_devices: Number of reader shards.
            file_root, file_list: Forwarded to the file reader.
            batch_size, num_threads: Forwarded to the base pipeline (seed 42).
            sample_rate: Used for the mel filter bank and to convert
                ``window_size``/``window_stride`` (presumably seconds) to samples.
            nfeatures, nfft, lowfreq, highfreq: Mel filter bank / spectrogram settings.
            frame_splicing_factor: Stored and used by the splicing reshape/pad ops.
            silence_threshold: ``cutoff_db`` of the non-silent region detector.
            dither: Stored as ``self.dither``.
            preemph_coeff: Pre-emphasis coefficient.
        """
        super().__init__(batch_size, num_threads, device_id, seed=42)

        self.dither = dither
        self.frame_splicing_factor = frame_splicing_factor
        on_cpu = {"device": "cpu"}

        # Reading / decoding, sharded by device id.
        self.read = ops.readers.File(file_root=file_root, file_list=file_list,
                                     shard_id=device_id, num_shards=n_devices, **on_cpu)
        self.decode = ops.AudioDecoder(dtype=types.FLOAT, downmix=True, **on_cpu)
        self.normal_distribution = ops.random.Normal(**on_cpu)
        self.preemph = ops.PreemphasisFilter(preemph_coeff=preemph_coeff)

        # Feature extraction.
        win_length = window_size * sample_rate
        win_step = window_stride * sample_rate
        self.spectrogram = ops.Spectrogram(nfft=nfft, window_length=win_length,
                                           window_step=win_step, **on_cpu)
        self.mel_fbank = ops.MelFilterBank(sample_rate=sample_rate, nfilter=nfeatures,
                                           normalize=True, freq_low=lowfreq,
                                           freq_high=highfreq, **on_cpu)
        self.log_features = ops.ToDecibels(multiplier=np.log(10), reference=1.0,
                                           cutoff_db=-80, **on_cpu)
        self.get_shape = ops.Shapes(**on_cpu)
        self.normalize = ops.Normalize(axes=[0], **on_cpu)

        # Frame-splicing operators.
        self.splicing_transpose = ops.Transpose(perm=[1, 0], **on_cpu)
        self.splicing_reshape = ops.Reshape(rel_shape=[-1, frame_splicing_factor], **on_cpu)
        self.splicing_pad = ops.Pad(axes=[0], fill_value=0, align=frame_splicing_factor,
                                    shape=[1], **on_cpu)

        # Silence trimming.
        self.get_nonsilent_region = ops.NonsilentRegion(cutoff_db=silence_threshold, **on_cpu)
        self.trim_silence = ops.Slice(axes=[0], **on_cpu)
        self.to_float = ops.Cast(dtype=types.FLOAT)
コード例 #5
0
ファイル: modules.py プロジェクト: jesslynsepthiaa/vortex
    def __init__(self,
                 input_size: int,
                 scaler: Union[int, float] = 255,
                 mean: Union[List[float], None] = None,
                 std: Union[List[float], None] = None,
                 image_pad_value: Union[int, float] = 0,
                 labels_pad_value: Union[int, float] = -99,
                 normalize: bool = True):
        """Initialization

        Args:
            input_size (int): Target size of image resize
            scaler (Union[int,float], optional): The scaling factor applied to the input pixel value. Defaults to 255.
            mean (List[float], optional): Mean pixel values for image normalization. ``None`` means [0.,0.,0.]. Defaults to None.
            std (List[float], optional): Standard deviation values for image normalization. ``None`` means [1.,1.,1.]. Defaults to None.
            image_pad_value (Union[int,float], optional): Value of the color used to pad the image to a square. Defaults to 0.
            labels_pad_value (Union[int,float], optional): Value used to pad the labels so they have the same dimension; the padding is removed by the dataloader. Defaults to -99.
            normalize (bool, optional): Will apply normalization if set to True. Defaults to True.
        """
        # Avoid shared mutable default arguments; fall back to the documented defaults.
        if mean is None:
            mean = [0., 0., 0.]
        if std is None:
            std = [1., 1., 1.]

        # CropMirrorNormalize computes (x - mean) / std on raw pixel values. Scaling
        # mean and std by 255 makes its output equal to (x / 255 - mean) / std, i.e.
        # the PyTorch-loader convention of normalizing pixels scaled to [0, 1].
        self.normalize = normalize
        self.image_normalize = ops.CropMirrorNormalize(
            device='gpu',
            mean=[value * 255 for value in mean],
            std=[value * 255 for value in std],
            output_layout='CHW',
            image_type=types.DALIImageType.BGR)

        # Normalize with mean=0, stddev=1 is a pure rescale by `scale`. Multiplying
        # by 255 / scaler undoes the implicit division by 255 above and applies the
        # user's `scaler` instead (presumably applied after image_normalize in the graph).
        self.scaler = ops.Normalize(device='gpu',
                                    scale=float(255 / scaler),
                                    mean=0,
                                    stddev=1)

        # Padding and resize to prepare tensor output
        self.image_pad = ops.Paste(device='gpu',
                                   fill_value=image_pad_value,
                                   ratio=1,
                                   min_canvas_size=input_size,
                                   paste_x=0,
                                   paste_y=0)
        self.labels_pad = ops.Pad(device='cpu',
                                  axes=(0, 1),
                                  fill_value=labels_pad_value)

        self.model_input_resize = ops.Resize(
            device='gpu',
            interp_type=types.DALIInterpType.INTERP_CUBIC,
            resize_longer=input_size)
        self.peek_shape = ops.Shapes(device='gpu')
コード例 #6
0
ファイル: math_ops.py プロジェクト: seetaresearch/dragon
    def __new__(cls,
                axes=(0, 1),
                mean=None,
                stddev=None,
                scale=1.0,
                shift=0.0,
                batch=False,
                epsilon=0,
                dtype='float32',
                **kwargs):
        """Create a ``Normalize`` operator.

        Parameters
        ----------
        axes : Sequence[int], optional
            The axes to normalize.
        mean : float, optional
            The value to subtract.
        stddev : float, optional
            The value to divide by after subtraction.
        scale : float, optional, default=1.0
            The scale factor after normalization.
        shift : float, optional, default=0.0
            The shift factor after normalization.
        batch : bool, optional, default=False
            Whether to compute mean and stddev across the batch.
        epsilon : float, optional, default=0
            The value added to the computed variance.
        dtype : str, optional, default='float32'
            The output data type. A string is resolved to the upper-cased
            attribute of the same name on ``types``.
        **kwargs
            Extra arguments forwarded to ``nvidia.dali.ops.Normalize``.

        Returns
        -------
        nvidia.dali.ops.Normalize
            The operator (no instance of ``cls`` is created).

        """
        if isinstance(dtype, six.string_types):
            dtype = getattr(types, dtype.upper())
        fixed_args = dict(
            axes=axes,
            mean=mean,
            stddev=stddev,
            scale=scale,
            shift=shift,
            batch=batch,
            epsilon=epsilon,
            dtype=dtype,
            device=context.get_device_type(),
        )
        # Unpacking both mappings keeps the original behavior: a key repeated in
        # **kwargs raises TypeError instead of silently overriding.
        return ops.Normalize(**fixed_args, **kwargs)
コード例 #7
0
ファイル: test.py プロジェクト: gddcx/yolov1
 def __init__(self, files, labels):
     """Construct the operators of the evaluation pipeline.

     Uses batch size 256, 8 threads, device 0 and seed 42. Files are read in
     order, decoded to RGB, resized (shorter side 224), cropped to 224x224,
     normalized per channel with the usual ImageNet statistics (scaled to
     0-255), and transposed with perm [2, 0, 1] (HWC -> CHW, assuming HWC input).

     Args:
         files, labels: Explicit file paths and labels for the reader.
     """
     batch_size, num_threads, device_id = 256, 8, 0
     super().__init__(batch_size, num_threads, device_id, seed=42)
     self.input = ops.FileReader(files=files, labels=labels, random_shuffle=False)
     # (sic) attribute name kept as-is; other code may refer to ``self.decocer``.
     self.decocer = ops.ImageDecoder(device='mixed', output_type=types.RGB)
     self.resize = ops.Resize(device='gpu', resize_shorter=224)
     self.crop = ops.Crop(device='gpu', crop_h=224, crop_w=224)
     # Shape (1, 1, 3): size-1 dims for the reduced axes [0, 1], scaled to 0-255.
     mean = torch.Tensor([0.485, 0.456, 0.406])[None, None] * 255
     std = torch.Tensor([0.229, 0.224, 0.225])[None, None] * 255
     self.normalize = ops.Normalize(axes=[0, 1], mean=mean, stddev=std,
                                    device='gpu', batch=False)
     self.transpose = ops.Transpose(device='gpu', perm=[2, 0, 1])
コード例 #8
0
ファイル: debug_without_sacred.py プロジェクト: gddcx/yolov1
 def __init__(self, params, device_id, files, labels):
     """Construct the operators of the (non-augmented) evaluation pipeline.

     Args:
         params: Config object; this constructor reads ``batch_size``,
             ``num_gpus`` (8 worker threads are used per GPU), ``seed``,
             ``mean`` and ``std``. ``mean``/``std`` are multiplied by 255 here,
             so they are presumably per-channel values on a 0-1 scale.
         device_id: Device id forwarded to the base pipeline.
         files, labels: Explicit file paths and labels, read in order.
     """
     num_threads = params.num_gpus * 8
     super().__init__(params.batch_size, num_threads, device_id, seed=params.seed)

     def to_pixel_scale(values):
         # Leading size-1 dims give shape (1, 1, C) for the reduced axes [0, 1];
         # values are scaled by 255.
         return torch.Tensor(values)[None, None] * 255

     self.input = ops.FileReader(files=files, labels=labels, random_shuffle=False)
     # (sic) attribute name kept as-is; other code may refer to ``self.decocer``.
     self.decocer = ops.ImageDecoder(device='mixed', output_type=types.RGB)
     self.resize = ops.Resize(device='gpu', resize_shorter=224)
     self.crop = ops.Crop(device='gpu', crop_h=224, crop_w=224)
     mean = to_pixel_scale(params.mean)
     std = to_pixel_scale(params.std)
     self.normalize = ops.Normalize(axes=[0, 1], mean=mean, stddev=std,
                                    device='gpu', batch=False)
     self.transpose = ops.Transpose(device='gpu', perm=[2, 0, 1])
コード例 #9
0
    def __init__(
            self,
            *,
            train_pipeline: bool,  # True if train pipeline, False if validation pipeline
            device_id,
            num_threads,
            batch_size,
            file_root: str,
            file_list: str,
            sample_rate,
            discrete_resample_range: bool,
            resample_range: list,
            window_size,
            window_stride,
            nfeatures,
            nfft,
            frame_splicing_factor,
            dither_coeff,
            silence_threshold,
            preemph_coeff,
            pad_align,
            max_duration,
            mask_time_num_regions,
            mask_time_min,
            mask_time_max,
            mask_freq_num_regions,
            mask_freq_min,
            mask_freq_max,
            mask_both_num_regions,
            mask_both_min_time,
            mask_both_max_time,
            mask_both_min_freq,
            mask_both_max_freq,
            preprocessing_device="gpu"):
        """Build the operators of the DALI audio preprocessing pipeline.

        All arguments are keyword-only.

        Args:
            train_pipeline: ``True`` for the training pipeline (enables
                ``shuffle_after_epoch`` on the reader), ``False`` for validation.
            device_id, num_threads, batch_size: Forwarded to the base pipeline.
            file_root, file_list: Forwarded to the file reader, which is sharded
                by the torch.distributed rank/world size when initialized.
            sample_rate: Passed to the decoder only when ``resample_range`` is
                ``None``; also used for the mel filter bank and to convert
                ``window_size``/``window_stride`` (presumably seconds) to samples.
            discrete_resample_range: If true and ``resample_range`` is set, speed
                perturbation coefficients come from
                ``self._discrete_resample_coeffs_generator``; otherwise they are
                drawn uniformly from ``resample_range``.
            resample_range: Speed-perturbation range, or ``None`` to disable it.
            nfeatures, nfft: Mel filter count and FFT size.
            frame_splicing_factor: Must be 1 (frame splicing is not supported).
            dither_coeff, max_duration: Stored for use elsewhere.
            silence_threshold: ``cutoff_db`` for silence detection; ``None``
                disables silence removal (``self.do_remove_silence``).
            preemph_coeff: Pre-emphasis coefficient.
            pad_align: Alignment for the ``Pad`` operator along axis 1.
            mask_*: Collected into ``self.mask_params`` for spectrogram masking.
            preprocessing_device: ``"cpu"`` or ``"gpu"`` (case-insensitive);
                device of the feature-extraction operators.

        Raises:
            AssertionError: if ``preprocessing_device`` is not cpu/gpu or
                ``frame_splicing_factor != 1``.
        """
        super().__init__(batch_size, num_threads, device_id)

        self._dali_init_log(locals())

        if torch.distributed.is_initialized():
            shard_id = torch.distributed.get_rank()
            n_shards = torch.distributed.get_world_size()
        else:
            shard_id = 0
            n_shards = 1

        self.preprocessing_device = preprocessing_device.lower()
        assert self.preprocessing_device == "cpu" or self.preprocessing_device == "gpu", \
            "Incorrect preprocessing device. Please choose either 'cpu' or 'gpu'"
        # Use the normalized name from here on: the check above is case-insensitive,
        # but DALI device strings are lower-case, so passing the raw argument
        # (e.g. "GPU") to the operators would fail.
        device = self.preprocessing_device
        self.frame_splicing_factor = frame_splicing_factor
        assert frame_splicing_factor == 1, "DALI doesn't support frame splicing operation"

        self.resample_range = resample_range
        self.discrete_resample_range = discrete_resample_range

        self.train = train_pipeline
        self.sample_rate = sample_rate
        self.dither_coeff = dither_coeff
        self.nfeatures = nfeatures
        self.max_duration = max_duration
        self.mask_params = {
            'time_num_regions': mask_time_num_regions,
            'time_min': mask_time_min,
            'time_max': mask_time_max,
            'freq_num_regions': mask_freq_num_regions,
            'freq_min': mask_freq_min,
            'freq_max': mask_freq_max,
            'both_num_regions': mask_both_num_regions,
            'both_min_time': mask_both_min_time,
            'both_max_time': mask_both_max_time,
            'both_min_freq': mask_both_min_freq,
            'both_max_freq': mask_both_max_freq,
        }
        self.do_remove_silence = silence_threshold is not None

        self.read = ops.FileReader(device="cpu",
                                   file_root=file_root,
                                   file_list=file_list,
                                   shard_id=shard_id,
                                   num_shards=n_shards,
                                   shuffle_after_epoch=train_pipeline)

        # TODO change ExternalSource to Uniform for new DALI release
        if discrete_resample_range and resample_range is not None:
            self.speed_perturbation_coeffs = ops.ExternalSource(
                device="cpu",
                cycle=True,
                source=self._discrete_resample_coeffs_generator)
        elif resample_range is not None:
            self.speed_perturbation_coeffs = random.Uniform(
                device="cpu", range=resample_range)
        else:
            self.speed_perturbation_coeffs = None

        self.decode = ops.AudioDecoder(
            device="cpu",
            sample_rate=self.sample_rate if resample_range is None else None,
            dtype=types.FLOAT,
            downmix=True)

        self.normal_distribution = random.Normal(device=device)

        self.preemph = ops.PreemphasisFilter(device=device,
                                             preemph_coeff=preemph_coeff)

        self.spectrogram = ops.Spectrogram(
            device=device,
            nfft=nfft,
            window_length=window_size * sample_rate,
            window_step=window_stride * sample_rate)

        self.mel_fbank = ops.MelFilterBank(device=device,
                                           sample_rate=sample_rate,
                                           nfilter=self.nfeatures,
                                           normalize=True)

        self.log_features = ops.ToDecibels(device=device,
                                           multiplier=np.log(10),
                                           reference=1.0,
                                           cutoff_db=math.log(1e-20))

        self.get_shape = ops.Shapes(device=device)

        self.normalize = ops.Normalize(device=device, axes=[1])

        self.pad = ops.Pad(device=device,
                           axes=[1],
                           fill_value=0,
                           align=pad_align)

        # Silence trimming
        self.get_nonsilent_region = ops.NonsilentRegion(
            device="cpu", cutoff_db=silence_threshold)
        self.trim_silence = ops.Slice(device="cpu",
                                      normalized_anchor=False,
                                      normalized_shape=False,
                                      axes=[0])
        self.to_float = ops.Cast(device="cpu", dtype=types.FLOAT)

        # Spectrogram masking
        self.spectrogram_cutouts = ops.ExternalSource(
            source=self._cutouts_generator, num_outputs=2, cycle=True)
        self.mask_spectrogram = ops.Erase(device=device,
                                          axes=[0, 1],
                                          fill_value=0,
                                          normalized_anchor=True)
コード例 #10
0
    def __init__(self,
                 device_id,
                 num_threads,
                 resample_range: list,
                 sample_rate=16000,
                 window_size=0.02,
                 window_stride=0.01,
                 window="hann",
                 normalize="per_feature",
                 n_fft=None,
                 preemph=0.97,
                 nfilt=64,
                 lowfreq=0,
                 highfreq=0,
                 log=True,
                 dither=constant,
                 pad_to=8,
                 max_duration=15.0,
                 frame_splicing=3,
                 batch_size=1,
                 total_samples=16,
                 audio_fp16_input=True,
                 device='gpu'):
        """Build the operators of an audio feature-extraction pipeline fed by ``ExternalSource``.

        Input arrives through an ``ExternalSource`` named ``"INPUT_0"``. The
        operators cover pre-emphasis, spectrogram, mel filter bank, dB
        conversion, normalization, padding, frame splicing and casts.

        Args:
            device_id, num_threads, batch_size: Forwarded to the base pipeline
                (async + pipelined execution, seed 12, prefetch depth 1).
            resample_range: Accepted but not used by this constructor.
            sample_rate: Audio sample rate; converts window size/stride to samples.
            window_size, window_stride: Window length and hop, multiplied by
                ``sample_rate`` (presumably seconds).
            window: ``'hann'``, ``'hamming'``, ``'blackman'``, ``'bartlett'``
                or ``'none'`` (rectangular window of ones).
            normalize: Forwarded to ``MelFilterBank(normalize=...)``.
            n_fft: FFT size; defaults to the next power of two >= the window length.
            preemph: Pre-emphasis coefficient.
            nfilt: Number of mel filters.
            lowfreq, highfreq: Mel filter bank range; a falsy ``highfreq`` means
                ``sample_rate / 2``.
            log, dither, pad_to, total_samples, audio_fp16_input: Stored on the
                instance for use elsewhere.
            max_duration: Accepted but not used by this constructor.
            frame_splicing: Splicing factor used by the splicing reshape/pad ops.
            device: DALI device used for every operator.

        Raises:
            ValueError: if ``window`` is not one of the supported names.
        """
        super().__init__(batch_size, num_threads, device_id,
                         exec_async=True, exec_pipelined=True, seed=12, prefetch_queue_depth=1)

        self._dali_init_log(locals())

        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        # Previously an unknown name silently became None and crashed later in
        # `self.window.tolist()`; fail early with a clear message instead.
        if window not in torch_windows:
            raise ValueError("Unsupported window {!r}; expected one of {}".format(
                window, sorted(torch_windows)))

        self.audio_fp16_input = audio_fp16_input
        self.total_samples = total_samples
        self.win_length = int(sample_rate * window_size)  # frame size, in samples
        self.hop_length = int(sample_rate * window_stride)
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))

        self.log = log
        self.dither = dither
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.pad_to = pad_to
        self.highfreq = highfreq or sample_rate / 2
        window_fn = torch_windows[window]
        window_tensor = window_fn(self.win_length,
                                  periodic=False) if window_fn else None

        self.sample_rate = sample_rate
        self.window_size = window_size
        self.window_stride = window_stride
        self.window = window_tensor

        self.lowfreq = lowfreq
        self.device = device

        # window='none' means a rectangular window, matching torch.stft(window=None).
        if self.window is None:
            win_unpadded = [1.0] * self.win_length
        else:
            win_unpadded = self.window.tolist()
        # Zero-pad the window up to n_fft so it can be passed as window_fn.
        win_padded = win_unpadded + [0] * (self.n_fft - len(win_unpadded))

        print("self.n_fft = {}".format(self.n_fft))
        print("self.hop_length = {}".format(self.hop_length))
        print("self.win_length = {}".format(self.win_length))
        print("self.window_tensor = {}".format(self.window))
        print("self.sample_rate = {}".format(self.sample_rate))
        print("self.window_size = {}".format(self.window_size))
        print("self.window_stride = {}".format(self.window_stride))
        print("self.lowfreq = {}".format(self.lowfreq))
        print("self.device = {}".format(self.device))

        self.extsrc = ops.ExternalSource(name="INPUT_0", device=self.device, no_copy=True)

        self.preemph = ops.PreemphasisFilter(preemph_coeff=preemph, device=self.device)

        self.spectrogram = ops.Spectrogram(device=self.device,
                                           nfft=self.n_fft,
                                           center_windows=True,
                                           window_fn=win_padded,
                                           window_length=len(win_padded),
                                           window_step=self.hop_length
                                           )
        # NOTE(review): `normalize` defaults to the string "per_feature", while
        # MelFilterBank's `normalize` is a boolean flag -- confirm the intended value.
        self.mel_fbank = ops.MelFilterBank(device=self.device,
                                           sample_rate=self.sample_rate,
                                           nfilter=self.nfilt,
                                           freq_high=self.highfreq,
                                           freq_low=self.lowfreq,
                                           normalize=normalize
                                           )

        self.log_features = ops.ToDecibels(device=self.device, multiplier=np.log(10), reference=1.0,
                                           cutoff_db=math.log(1e-20))

        self.get_shape = ops.Shapes(device=self.device)

        self.normalize = ops.Normalize(axes=[0], device=self.device, ddof=1)

        # NOTE(review): hard-coded output shape; presumably matches a fixed model
        # input size -- confirm before changing window/feature parameters.
        self.pad = ops.Pad(axes=[0, 1], fill_value=0, shape=[502, 240], device=self.device)

        # Frame splicing
        self.splicing_transpose = ops.Transpose(device=self.device, perm=[1, 0])
        self.splicing_reshape = ops.Reshape(device=self.device, rel_shape=[-1, self.frame_splicing])
        self.splicing_pad = ops.Pad(axes=[0], fill_value=0, align=self.frame_splicing, shape=[1], device=self.device)

        self.to_float16 = ops.Cast(dtype=types.FLOAT16, device=self.device)
        self.to_float32 = ops.Cast(dtype=types.FLOAT, device=self.device)

        self.samples_done = 0
コード例 #11
0
ファイル: pipeline.py プロジェクト: graphcore/examples
    def __init__(self,
                 *,
                 pipeline_type,
                 device_id,
                 num_threads,
                 batch_size,
                 file_root: str,
                 sampler,
                 sample_rate,
                 resample_range: list,
                 window_size,
                 window_stride,
                 nfeatures,
                 nfft,
                 dither_coeff,
                 silence_threshold,
                 preemph_coeff,
                 max_duration,
                 preprocessing_device="gpu"):
        """Build the operators of the DALI audio preprocessing pipeline.

        All arguments are keyword-only.

        Args:
            pipeline_type: ``'train'`` enables reader shuffling (unless the
                sampler is already random); ``'val'`` enables ``pad_last_batch``.
            device_id, num_threads, batch_size: Forwarded to the base pipeline.
            file_root: Forwarded to the file reader, which is sharded by the
                torch.distributed rank/world size when initialized.
            sampler: Provides the file list (``get_file_list_path()``) and
                whether sampling is already random (``is_sampler_random()``).
            sample_rate: Passed to the decoder only when ``resample_range`` is
                ``None``; also used for the mel filter bank and to convert
                ``window_size``/``window_stride`` (presumably seconds) to samples.
            resample_range: Range of a uniform speed-perturbation coefficient,
                or ``None`` to disable it.
            nfeatures, nfft: Mel filter count and FFT size.
            dither_coeff, max_duration: Stored for use elsewhere.
            silence_threshold: ``cutoff_db`` for silence detection; ``None``
                disables silence removal (``self.do_remove_silence``).
            preemph_coeff: Pre-emphasis coefficient.
            preprocessing_device: ``"cpu"`` or ``"gpu"`` (case-insensitive);
                device of the feature-extraction operators.

        Raises:
            AssertionError: if ``preprocessing_device`` is not cpu/gpu.
        """
        super().__init__(batch_size, num_threads, device_id)

        self._dali_init_log(locals())

        if torch.distributed.is_initialized():
            shard_id = torch.distributed.get_rank()
            n_shards = torch.distributed.get_world_size()
        else:
            shard_id = 0
            n_shards = 1

        self.preprocessing_device = preprocessing_device.lower()
        assert self.preprocessing_device == "cpu" or self.preprocessing_device == "gpu", \
            "Incorrect preprocessing device. Please choose either 'cpu' or 'gpu'"
        # Use the normalized name from here on: the check above is case-insensitive,
        # but DALI device strings are lower-case, so passing the raw argument
        # (e.g. "GPU") to the operators would fail.
        device = self.preprocessing_device

        self.resample_range = resample_range

        train_pipeline = pipeline_type == 'train'
        self.train = train_pipeline
        self.sample_rate = sample_rate
        self.dither_coeff = dither_coeff
        self.nfeatures = nfeatures
        self.max_duration = max_duration
        self.do_remove_silence = silence_threshold is not None

        # Only shuffle in the reader when the sampler does not already randomize.
        shuffle = train_pipeline and not sampler.is_sampler_random()
        self.read = ops.FileReader(name="Reader",
                                   pad_last_batch=(pipeline_type == 'val'),
                                   device="cpu",
                                   file_root=file_root,
                                   file_list=sampler.get_file_list_path(),
                                   shard_id=shard_id,
                                   num_shards=n_shards,
                                   shuffle_after_epoch=shuffle)

        # Speed perturbation: uniform coefficient in resample_range, if enabled.
        if resample_range is not None:
            self.speed_perturbation_coeffs = ops.Uniform(device="cpu",
                                                         range=resample_range)
        else:
            self.speed_perturbation_coeffs = None

        self.decode = ops.AudioDecoder(
            device="cpu",
            sample_rate=self.sample_rate if resample_range is None else None,
            dtype=types.FLOAT,
            downmix=True)

        self.normal_distribution = ops.NormalDistribution(device=device)

        self.preemph = ops.PreemphasisFilter(device=device,
                                             preemph_coeff=preemph_coeff)

        self.spectrogram = ops.Spectrogram(
            device=device,
            nfft=nfft,
            window_length=window_size * sample_rate,
            window_step=window_stride * sample_rate)

        self.mel_fbank = ops.MelFilterBank(device=device,
                                           sample_rate=sample_rate,
                                           nfilter=self.nfeatures,
                                           normalize=True)

        self.log_features = ops.ToDecibels(device=device,
                                           multiplier=np.log(10),
                                           reference=1.0,
                                           cutoff_db=math.log(1e-20))

        self.get_shape = ops.Shapes(device=device)

        self.normalize = ops.Normalize(device=device, axes=[1])

        self.pad = ops.Pad(device=device, fill_value=0)

        # Silence trimming
        self.get_nonsilent_region = ops.NonsilentRegion(
            device="cpu", cutoff_db=silence_threshold)
        self.trim_silence = ops.Slice(device="cpu",
                                      normalized_anchor=False,
                                      normalized_shape=False,
                                      axes=[0])
        self.to_float = ops.Cast(device="cpu", dtype=types.FLOAT)