Example #1
    def forward(self, class_one, class_two):
        batch_size, depth, height, width = class_one.size()

        class_one_flat = class_one.permute(0, 2, 3, 1).contiguous().view(
            -1, self.input_dim1)
        class_two_flat = class_two.permute(0, 2, 3, 1).contiguous().view(
            -1, self.input_dim2)

        sketch_1 = class_one_flat.mm(self.sparse_sketch_matrix1)
        sketch_2 = class_two_flat.mm(self.sparse_sketch_matrix2)

        fft1_real = afft.fft(sketch_1)
        fft1_imag = afft.fft(Variable(torch.zeros(sketch_1.size())).cuda())
        fft2_real = afft.fft(sketch_2)
        fft2_imag = afft.fft(Variable(torch.zeros(sketch_2.size())).cuda())

        fft_product_real = fft1_real.mul(fft2_real) - fft1_imag.mul(fft2_imag)
        fft_product_imag = fft1_real.mul(fft2_imag) + fft1_imag.mul(fft2_real)

        cbp_flat = afft.ifft(fft_product_real)
        #cbd_flat_2 = afft.ifft(fft_product_imag)[0]

        cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

        if self.sum_pool:
            cbp = cbp.sum(dim=1).sum(dim=1)

        return cbp.float()
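
The real/imaginary bookkeeping above predates the complex torch.fft API. With torch >= 1.8 the same frequency-domain count-sketch product (a circular convolution of the two sketches) can be written much more compactly; a hedged standalone sketch, not the original module's code:

import torch
from torch.fft import fft, ifft

def sketch_product(sketch_1: torch.Tensor, sketch_2: torch.Tensor) -> torch.Tensor:
    """Circular convolution of two count sketches via an element-wise product in the frequency domain."""
    return ifft(fft(sketch_1, dim=-1) * fft(sketch_2, dim=-1), dim=-1).real

# toy shapes: (num_positions, output_dim)
s1 = torch.randn(32, 8000)
s2 = torch.randn(32, 8000)
cbp_flat = sketch_product(s1, s2)   # (32, 8000)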
Example #2
def hilbert(x, ndft=None):
    r"""Analytic signal of x.

    Return the analytic signal of a real signal x, x + j\hat{x}, where \hat{x}
    is the Hilbert transform of x.

    Parameters
    ----------
    x: torch.Tensor
        Audio signal to be analyzed.
        Always assumes x is real, and x.shape[-1] is the signal length.
    ndft: int, optional
        FFT length; if given, the signal is zero-padded to this length
        (must be greater than x.size(-1)) before the transform.

    Returns
    -------
    out: torch.Tensor
        Complex tensor holding the analytic signal, with
        out.shape == (*x.shape[:-1], ndft or x.shape[-1])

    """
    if ndft is None:
        sig = x
    else:
        assert ndft > x.size(-1)
        sig = F.pad(x, (0, ndft - x.size(-1)))
    xspec = fft.fft(sig)
    siglen = sig.size(-1)
    hh = torch.zeros(siglen, dtype=sig.dtype, device=sig.device)
    if siglen % 2 == 0:
        hh[0] = hh[siglen // 2] = 1
        hh[1:siglen // 2] = 2
    else:
        hh[0] = 1
        hh[1:(siglen + 1) // 2] = 2

    return fft.ifft(xspec * hh)
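
A brief usage sketch for the function above (assuming `fft` is torch.fft and `F` is torch.nn.functional, as the body implies): the magnitude of the analytic signal gives the envelope, and its angle the instantaneous phase.

import math
import torch

t = torch.linspace(0, 1, 1000)
x = torch.sin(2 * math.pi * 50 * t) * torch.exp(-3 * t)   # decaying 50 Hz tone

analytic = hilbert(x)        # complex tensor, same length as x
envelope = analytic.abs()    # instantaneous amplitude
phase = analytic.angle()     # instantaneous phase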
Example #3
 def forward(self, x):
     x = fft(x, dim=2, norm='forward').abs()[:, :,
                                             0:ModelBase.InputShape[0]]
     x = torch.relu(self.conv1(x))
     x = self.maxpool1(x)
     x = x.view(-1, self.fcInputShape[0] * self.fcInputShape[1])
     x = torch.relu(self.fc1(x))
     return x
Example #4
 def forward(self, x):
     bsz, seq_len, embed_dim = x.shape
     self.weights = self.weights.type(x.dtype)
     w_2 = fft(torch.eye(seq_len))
     w_2 = torch.cat([w_2.real, w_2.imag], dim=-1).type(x.dtype)
     x = torch.matmul(x, self.weights)
     x = torch.matmul(x.transpose(1, 2), w_2).transpose(1, 2)
     return x
Example #5
    def __init_kernel__(self):
        """
        Generate the enframe_kernel, fft_kernel, ifft_kernel and overlap-add kernel.
        ** enframe_kernel: a conv1d layer built from an identity matrix.
        ** fft_kernel: a linear layer used for the DFT matrix multiplication. In fact,
        enframe_kernel and fft_kernel could be combined, but they are kept
        separate for readability.
        ** ifft_kernel: the pseudo-inverse (pinv) of fft_kernel.
        ** overlap-add kernel: just like enframe_kernel, but transposed.
        
        Returns:
            tuple: four kernels.
        """
        enframed_kernel = th.eye(self.fft_len)[:, None, :]
        if support_clp_op:
            tmp = fft(th.eye(self.fft_len))
            fft_kernel = th.stack([tmp.real, tmp.imag], dim=2)
        else:
            fft_kernel = fft(th.eye(self.fft_len), 1)
        if self.mode == 'break':
            enframed_kernel = th.eye(self.win_len)[:, None, :]
            fft_kernel = fft_kernel[:self.win_len]
        fft_kernel = th.cat((fft_kernel[:, :, 0], fft_kernel[:, :, 1]), dim=1)
        ifft_kernel = th.pinverse(fft_kernel)[:, None, :]
        window = get_window(self.win_type, self.win_len)

        self.perfect_reconstruct = check_COLA(window, self.win_len,
                                              self.win_len - self.win_hop)
        window = th.FloatTensor(window)
        if self.mode == 'continue':
            left_pad = (self.fft_len - self.win_len) // 2
            right_pad = left_pad + (self.fft_len - self.win_len) % 2
            window = F.pad(window, (left_pad, right_pad))
        if self.win_sqrt:
            self.padded_window = window
            window = th.sqrt(window)
        else:
            self.padded_window = window**2

        fft_kernel = fft_kernel.T * window
        ifft_kernel = ifft_kernel * window
        ola_kernel = th.eye(self.fft_len)[:self.win_len, None, :]
        if self.mode == 'continue':
            ola_kernel = th.eye(self.fft_len)[:, None, :self.fft_len]
        return enframed_kernel, fft_kernel, ifft_kernel, ola_kernel
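
The fft_kernel construction exploits the fact that the FFT of an identity matrix is the (symmetric) DFT matrix, so transforming a frame becomes a plain matrix multiplication, which is exactly what the linear/conv kernels above implement. A minimal, self-contained check of that equivalence (assuming only torch >= 1.8; not part of the class above):

import torch

N = 512
tmp = torch.fft.fft(torch.eye(N, dtype=torch.float64))    # FFT of the identity = (symmetric) DFT matrix
fft_kernel = torch.cat([tmp.real, tmp.imag], dim=1)       # (N, 2N): real columns, then imaginary columns

frame = torch.randn(N, dtype=torch.float64)
spec = frame @ fft_kernel                                  # DFT as a plain matrix multiplication
ref = torch.fft.fft(frame)
assert torch.allclose(spec[:N], ref.real)
assert torch.allclose(spec[N:], ref.imag)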
Example #6
    def __getitem__(self, i):
        fname, slice_id = self.examples[i]
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]
            kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))

            # For 1.8+
            # pytorch now offers a complex64 data type
            kspace = torch.view_as_complex(kspace)
            kspace = ifftshift(kspace, dim=(0, 1))
            # norm="forward" puts all scaling on the forward FFT, so this ifft applies none
            target = ifft(kspace, dim=(0, 1), norm="forward")
            target = ifftshift(target, dim=(0, 1))

            # Plot images to confirm fft worked
            # t_img = complex_magnitude(target)
            # print(t_img.dtype, t_img.shape)
            # plt.imshow(t_img)
            # plt.show()
            # plt.imshow(target.real)
            # plt.show()

            # center crop and resize
            # target = torch.unsqueeze(target, dim=0)
            # target = center_crop(target, (128, 128))
            # target = torch.squeeze(target)

            # Crop out ends
            target = np.stack([target.real, target.imag], axis=-1)
            target = target[100:-100, 24:-24, :]

            # Downsample in image space
            shape = target.shape
            target = tf.image.resize(
                target,
                (IMG_H, IMG_W),
                method="lanczos5",
                # preserve_aspect_ratio=True,
                antialias=True,
            ).numpy()

            # Get kspace of cropped image
            target = torch.view_as_complex(torch.from_numpy(target))
            kspace = fftshift(target, dim=(0, 1))
            kspace = fft(kspace, dim=(0, 1))
            # Realign k-space so the low-frequency (high-energy) content sits in the center
            # Note that the original fastMRI code did not do this...
            kspace = fftshift(kspace, dim=(0, 1))

            # Normalize using mean of k-space in training data
            target /= 7.072103529760345e-07
            kspace /= 7.072103529760345e-07

        return kspace, target
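
The shift bookkeeping above (ifftshift around the inverse transform, fftshift around the forward transform) is the standard recipe for working with a centered k-space. A small, self-contained sketch of that round trip under the same convention (torch >= 1.8 assumed; not taken from the dataset class above):

import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift

ksp_centered = torch.randn(320, 320, dtype=torch.complex128)  # stand-in for centered k-space

# centered k-space -> image
img = fftshift(ifft2(ifftshift(ksp_centered, dim=(0, 1)), dim=(0, 1)), dim=(0, 1))

# image -> centered k-space (round trip recovers the input)
ksp_back = fftshift(fft2(ifftshift(img, dim=(0, 1)), dim=(0, 1)), dim=(0, 1))
assert torch.allclose(ksp_back, ksp_centered)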
Example #7
def _hilbert_transform(x):
    if torch.is_complex(x):
        raise ValueError("x must be real.")

    N = x.shape[-1]
    f = fft(x, N, dim=-1)
    h = torch.zeros_like(f)
    if N % 2 == 0:
        h[..., 0] = h[..., N // 2] = 1
        h[..., 1:N // 2] = 2
    else:
        h[..., 0] = 1
        h[..., 1:(N + 1) // 2] = 2

    return ifft(f * h, dim=-1)
Example #8
    def forward_with_plt(self, x):
        x = fft(x, dim=2, norm='forward').abs()[:, :,
                                                0:ModelBase.InputShape[0]]
        x = x / x.max()

        x1 = x.detach().cpu().numpy()
        x = torch.relu(self.conv1(x))
        x2 = x.detach().cpu().numpy()
        x = self.maxpool1(x)
        x3 = x.detach().cpu().numpy()
        x = x.view(-1, self.fcInputShape[0] * self.fcInputShape[1])
        x = torch.relu(self.fc1(x))
        x4 = x.detach().cpu().numpy()
        x = torch.relu(self.fc2(x))
        x5 = x.detach().cpu().numpy()
        return x1, x2, x3, x4, x5
Example #9
    def __getitem__(self, i):
        fname, slice_id = self.examples[i]
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]
            kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))
            kspace = torch.view_as_complex(kspace)
            kspace = ifftshift(kspace, dim=(0, 1))
            target = ifft(kspace, dim=(0, 1), norm="forward")
            target = ifftshift(target, dim=(0, 1))

            # transform
            target = torch.stack([target.real, target.imag])
            target = self.deform(target)  # outputs numpy
            target = torch.from_numpy(target)
            target = target.permute(1, 2, 0).contiguous()
            # center crop and resize
            # target = center_crop(target, (128, 128))
            # target = resize(target, (128,128))

            # Crop out ends
            target = target.numpy()[100:-100, 24:-24, :]
            # Downsample in image space
            target = tf.image.resize(
                target,
                (IMG_H, IMG_W),
                method="lanczos5",
                # preserve_aspect_ratio=True,
                antialias=True,
            ).numpy()

            # Making contiguous is necessary for complex view
            target = torch.from_numpy(target)
            target = target.contiguous()
            target = torch.view_as_complex(target)

            kspace = fftshift(target, dim=(0, 1))
            kspace = fft(kspace, dim=(0, 1))
            kspace = fftshift(kspace, dim=(0, 1))

            # Normalize using mean of k-space in training data
            target /= 7.072103529760345e-07
            kspace /= 7.072103529760345e-07

        return kspace, target
Example #10
    dtype = dtype or torch.get_default_dtype()
    backend = dict(dtype=dtype,
                   layout=layout,
                   device=device,
                   requires_grad=requires_grad)

    if _torch_has_fftshift:
        return fft_mod.rfftfreq(n, d, **backend)

    f = torch.arange(n // 2 + 1, **backend)
    f /= (d * n)
    return f


if _torch_has_fft_module:
    fft = lambda *a, real=None, **k: fft_mod.fft(*a, **k)
else:

    def fft(input, n=None, dim=-1, norm='backward', real=None):
        """One dimensional discrete Fourier transform.

        Parameters
        ----------
        input : tensor
            Input signal.
            If torch <= 1.5, the last dimension must be of length 2 and
            contain the real and imaginary parts of the signal, unless
            `real is True`.
        n : int, optional
            Signal length. If given, the input will either be zero-padded
            or trimmed to this length before computing the FFT.
Example #11
	def __init__(
		self,
		out_channels,
		sample_rate,
		window_size,
		window_stride,
		window,
		dither = 1e-5,
		dither0 = 0.0,
		preemphasis = 0.97,
		eps = torch.finfo(torch.float16).tiny,
		normalize_signal = True,
		debug_short_long_records_normalize_signal_multiplier = 1.0,
		stft_mode = None,
		window_periodic = True,
		normalize_features = False,
		**kwargs
	):
		super().__init__()
		self.debug_short_long_records_normalize_signal_multiplier = debug_short_long_records_normalize_signal_multiplier
		self.stft_mode = stft_mode
		self.dither = dither
		self.dither0 = dither0
		self.preemphasis = preemphasis
		self.normalize_signal = normalize_signal
		self.sample_rate = sample_rate

		self.win_length = int(window_size * sample_rate)
		self.hop_length = int(window_stride * sample_rate)
		self.nfft = 2**math.ceil(math.log2(self.win_length))
		self.freq_cutoff = self.nfft // 2 + 1

		self.register_buffer('window', getattr(torch, window)(self.win_length, periodic = window_periodic).float())
		#mel_basis = torchaudio.functional.create_fb_matrix(n_fft, n_mels = num_input_features, fmin = 0, fmax = int(sample_rate/2)).t() # when https://github.com/pytorch/audio/issues/287 is fixed
		mel_basis = torch.as_tensor(
			librosa.filters.mel(sample_rate, self.nfft, n_mels = out_channels, fmin = 0, fmax = int(sample_rate / 2))
		)
		self.mel = nn.Conv1d(mel_basis.shape[1], mel_basis.shape[0], 1).requires_grad_(False)
		self.mel.weight.copy_(mel_basis.unsqueeze(-1))
		self.mel.bias.fill_(eps)

		'''
		NOTE: Comparison of new and old fourier bases. 
		Caused by deprecation of torch.rfft in version 1.8 and above https://pytorch.org/docs/1.7.1/generated/torch.rfft.html?highlight=rfft#torch.rfft
		https://github.com/pytorch/pytorch/issues/49637#issuecomment-806532068
		
		nfft = 512
		fcutoff = 257
		
		fourier_basis = torch.rfft(torch.eye(nfft), signal_ndim=1, onesided=False)
		forward_basis = fourier_basis[:fcutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
		
		fourier_basis_new = torch.view_as_real(fft.fft(torch.eye(nfft), dim=1))
		forward_basis_new = fourier_basis_new[:fcutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis_new.shape[1])
		
		diff = forward_basis-forward_basis_new
		print('basis diff', diff.mean())
		
		assert torch.allclose(forward_basis, forward_basis_new)
		
		'''
		if stft_mode == 'conv':
			fourier_basis = torch.view_as_real(fft.fft(torch.eye(self.nfft), dim=1))
			forward_basis = fourier_basis[:self.freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
			forward_basis = forward_basis * torch.as_tensor(
				librosa.util.pad_center(self.window, self.nfft), dtype = forward_basis.dtype
			)
			self.stft = nn.Conv1d(
				forward_basis.shape[1],
				forward_basis.shape[0],
				forward_basis.shape[2],
				bias = False,
				stride = self.hop_length
			).requires_grad_(False)
			self.stft.weight.copy_(forward_basis)
		else:
			self.stft = None
Example #12
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str,
                                                   Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        attn_mask = None  # TODO: fix this; the if condition is not entered in the main transformer architecture.
        is_tpu = query.device.type == "xla"

        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        k = v = prev_key = prev_value = None
        key_padding_mask = None
        if self.self_attention:
            q = self.q_proj(query)
            #k = self.k_proj(query)
            #v = self.v_proj(query)
            #assert True, f"q: {q.shape}, k: {k.shape}, v: {v.shape}"
            #q = (self.conv_q(q.permute(1, 2, 0).contiguous()).contiguous())
            #k= k.permute(1, 2, 0).contiguous()  # B * C * T
            #v = v.permute(1, 2, 0).contiguous()  # B * C * T
            #if self.shared_qkv_conv == 0:
            #    k = (self.conv_k(k).contiguous())
            #    v = (self.conv_v(v).contiguous())
            #if self.shared_qkv_conv == 1:  # use shared kv compressed linear layer
            #    k = (self.conv_q(k).contiguous())
            #    v = (self.conv_q(v).contiguous())
            #assert True, f"q: {q.shape}, k: {k.shape}, v: {v.shape}"
            assert True, f"q: {q.shape}"
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            #q = (self.conv_q(q.permute(1, 2, 0).contiguous()).contiguous())
            #if key is None:
            #    assert value is None
            #    k = v = None
            #else:
            #    k = self.k_proj(key)
            #    v = self.v_proj(key)
            #    k= k.permute(1, 2, 0).contiguous()  # B * C * T
            #    v = v.permute(1, 2, 0).contiguous()  # B * C * T
            #    if self.shared_qkv_conv == 0:
            #        k = (self.conv_k(k).contiguous())
            #        v = (self.conv_v(v).contiguous())
            #    if self.shared_qkv_conv == 1:  # use shared kv compressed linear layer
            #        k = (self.conv_q(k).contiguous())
            #        v = (self.conv_q(v).contiguous())

        else:
            #assert key is not None and value is not None
            q = self.q_proj(query)
        # q = (self.conv_q(q.permute(1, 2, 0).contiguous()).contiguous())
        # k = self.k_proj(key)
        # v = self.v_proj(value)
        # k= k.permute(1, 2, 0).contiguous()  # B * C * T
        # v = v.permute(1, 2, 0).contiguous()  # B * C * T
        # if self.shared_qkv_conv == 0:
        #     k = (self.conv_k(k).contiguous())
        #     v = (self.conv_v(v).contiguous())
        # if self.shared_qkv_conv == 1:  # use shared kv compressed linear layer
        #     k = (self.conv_q(k).contiguous())
        #     v = (self.conv_q(v).contiguous())

        q *= self.scaling
        assert True, f"q: {q.shape}"

        #if self.bias_k is not None:
        #    assert self.bias_v is not None
        #    k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        #    v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        #    if attn_mask is not None:
        #        attn_mask = torch.cat(
        #            [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
        #        )
        #    if key_padding_mask is not None:
        #        key_padding_mask = torch.cat(
        #            [
        #                key_padding_mask,
        #                key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
        #            ],
        #            dim=1,
        #        )

        # the below line is important since it adjusts the vector according to num_heads
        q = (q.contiguous().view(tgt_len, bsz * self.num_heads,
                                 self.head_dim).transpose(0, 1))
        if k is not None:
            k = (k.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))
        if v is not None:
            v = (v.contiguous().view(-1, bsz * self.num_heads,
                                     self.head_dim).transpose(0, 1))

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            k = q.clone() if k is None else k
            v = k
            #v = q.clone() if not v else v
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1,
                                          self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1,
                                              self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            #assert k is not None and v is not None
            key_padding_mask = FNet._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1,
                                             self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1,
                                               self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state,
                                                       saved_state)
        assert True, f"q: {q.shape}"  # bsz * head_dim,  src_len , head_dim
        #assert k is not None
        #src_len = k.size(1)
        #assert src_len == k.size(1)
        # the above line does not exist in the original transformer
        # the line below exists instead
        #assert k.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)

        #assert True, f"q: {q.shape}, k: {k.shape}, v: {v.shape}"
        #requires_convertion_to_float16= q.dtype == torch.float16
        #original_type = q.dtype
        attn_weights = fft(fft(q.type(torch.float32)),
                           axis=-2).real.type(q.dtype)
        #if requires_convertion_to_float16:
        #    attn_weights = attn_weights.type(torch.float16)
        #attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = FNet.apply_sparse_mask(
            attn_weights=attn_weights,
            tgt_len=tgt_len,
            src_len=src_len,
            bsz=bsz,
        )

        assert True, f"attn_weights: {attn_weights.shape}, bsz*num_heads: {bsz * self.num_heads}, src_len: {src_len}, tgt_len: {tgt_len}"
        #assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])],
                          dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])],
                          dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask,
                     attn_mask.new_zeros(attn_mask.size(0), 1)],
                    dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0),
                                    1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        #attn_weights = torch.bmm(q, k.transpose(1, 2))
        #attn_weights = FNet.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        #assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            assert attn_weights.shape == attn_mask.shape, f"attn_weights shape {attn_weights.shape} is not the same as attn_mask shape which is {attn_mask.shape}"
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len,
                                             src_len)
            #if not is_tpu:
            #    attn_weights = attn_weights.masked_fill(
            #        key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
            #        float("-inf"),
            #    )
            #else:
            #    attn_weights = attn_weights.transpose(0, 2)
            #    attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
            #    attn_weights = attn_weights.transpose(0, 2)
            #attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            assert attn_weights.shape == attn_mask.shape, f"attn_weights shape {attn_weights.shape} is not the same as attn_mask shape which is {attn_mask.shape}"
            attn_weights += attn_mask

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights,
                                           dim=-1,
                                           onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights,
            p=self.dropout,
            training=self.training,
        )
        #assert v is not None
        #attn = torch.bmm(attn_probs, v)
        attn = attn_probs
        assert list(
            attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0,
                                  1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            #attn_weights = attn_weights_float.view(
            #    bsz, self.num_heads, tgt_len, src_len
            #).transpose(1, 0)
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
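
The core change relative to standard multi-head attention is the attn_weights line: instead of q @ k^T, the query is passed through two FFTs (over the feature dimension and the sequence dimension) and only the real part is kept, in the spirit of FNet. A minimal standalone sketch of that mixing step (torch >= 1.8 assumed; names are illustrative and not taken from the class above):

import torch

def fnet_mixing(x: torch.Tensor) -> torch.Tensor:
    """FNet-style token mixing: FFT over the feature dim, FFT over the sequence dim, keep the real part."""
    return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real

x = torch.randn(8, 128, 256)   # (batch, seq_len, embed_dim)
mixed = fnet_mixing(x)
print(mixed.shape)             # torch.Size([8, 128, 256])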
Example #13
def _ft_surrogate(x=None, f=None, eps=1, random_state=None):
    """FT surrogate augmentation of a single EEG channel, as proposed in [1]_.

    Function copied from https://github.com/cliffordlab/sleep-convolutions-tf
    and modified.

    MIT License

    Copyright (c) 2018 Clifford Lab

    Permission is hereby granted, free of charge, to any person obtaining a
    copy of this software and associated documentation files (the "Software"),
    to deal in the Software without restriction, including without limitation
    the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the
    Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
    DEALINGS IN THE SOFTWARE.

    Parameters
    ----------
    x: torch.tensor, optional
        Single EEG channel signal in time space. Should not be passed if f is
        given. Defaults to None.
    f: torch.tensor, optional
        Fourier spectrum of a single EEG channel signal. Should not be passed
        if x is given. Defaults to None.
    eps: float, optional
        Float between 0 and 1 setting the range over which the phase
        perturbation is uniformly sampled: [0, `eps` * 2 * `pi`]. Defaults to 1.
    random_state: int | numpy.random.Generator, optional
        By default None.

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """
    assert isinstance(
        eps,
        (Real, torch.FloatTensor, torch.cuda.FloatTensor)
    ) and 0 <= eps <= 1, f"eps must be a float between 0 and 1. Got {eps}."
    if f is None:
        assert x is not None, 'Neither x nor f provided.'
        f = fft(x.double(), dim=-1)
        device = x.device
    else:
        device = f.device
    n = f.shape[-1]
    random_phase = _new_random_fft_phase[n % 2](
        n,
        device=device,
        random_state=random_state
    )
    f_shifted = f * torch.exp(eps * random_phase)
    shifted = ifft(f_shifted, dim=-1)
    return shifted.real.float()
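
The helper above leans on an external `_new_random_fft_phase` factory (not shown). A compact, self-contained sketch of the same idea, random Fourier-phase perturbation followed by an inverse transform, assuming torch >= 1.8 and skipping the Hermitian-symmetric phase construction the original helper presumably provides:

import math
import torch
from torch.fft import fft, ifft

def ft_surrogate(x: torch.Tensor, eps: float = 1.0, generator: torch.Generator = None) -> torch.Tensor:
    """Randomize the Fourier phases of a real 1-D signal, approximately preserving its power spectrum."""
    f = fft(x.double(), dim=-1)
    angles = 2 * math.pi * eps * torch.rand(f.shape[-1], dtype=torch.float64, generator=generator)
    random_phase = torch.polar(torch.ones_like(angles), angles)  # exp(1j * angles)
    return ifft(f * random_phase, dim=-1).real.float()

x = torch.randn(1024)
x_aug = ft_surrogate(x, eps=0.5)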
Example #14
 def preprocess(self, X):
     Z = F.normalize(X)
     return fft(Z, norm='ortho', dim=2)
Example #15
def morlet_fft_convolution(x, log_scales, dtime, unpadding=0, density=True, omega0=6.0):
    """
    Calculates a Morlet continuous wavelet transform
    for a given signal across a range of frequencies

    Parameters:
    ===========
    x : torch.Tensor, shape (batch, channels, sequence)
        A batch of multichannel sequences
    log_scales : torch.Tensor, shape (1, 1, freqs, 1)
        A tensor of logarithmic scales
    dtime : float
        Change in time per sample. The inverse of the sampling frequency.
    unpadding : int
        The amount of padding to remove from each side of the sequence
    density : bool (default = True)
        Whether to normalize so the power spectrum sums to one.
        This effectively removes amplitude fluctuations.
    omega0 : float
        Dimensionless omega0 parameter for wavelet transform
    Returns:
    ========
    out : torch.Tensor, shape (batch, channels, freqs, sequence)
        The transformed signal.
    """

    n_sequence = x.shape[-1]

    # Pad with extra zero if needed
    pad_sequence = n_sequence % 2 != 0
    x = F.pad(x, (0, 1)) if pad_sequence else x

    # Set index to remove the extra zero if added
    idx0 = (n_sequence // 2) + unpadding
    idx1 = (
        (n_sequence // 2) + n_sequence - unpadding - 1
        if pad_sequence
        else (n_sequence // 2) + n_sequence - unpadding
    )
    x = F.pad(x, (n_sequence // 2, n_sequence // 2))

    # (batch, channels, sequence) -> (batch, channels, freqs, sequence)
    x = x.unsqueeze(-2)

    # Calculate the omega values
    n_padded = x.shape[-1]
    omegas = (
        -2
        * np.pi
        * torch.arange(start=-n_padded // 2, end=n_padded // 2, device=x.device)
        / (n_padded * dtime)
    )[
        None, None, None
    ]  # (sequence,) -> (batch, channels, freqs, sequence)

    # Fourier transform the padded signal
    x_hat = fftshift1d(fft.fft(x, dim=-1))

    # Calculate the wavelets
    morlet = morlet_conj_ft(omegas * log_scales.exp(), omega0)

    # Perform the wavelet transform
    convolved = (
        fft.ifft(morlet * x_hat, dim=-1)[..., idx0:idx1] * log_scales.mul(0.5).exp()
    )

    power = convolved.abs()

    # scale power to account for the disproportionately
    # large response at low frequencies
    # power_scale = (
    #    np.pi ** -0.25
    #    * np.exp(0.25 * (omega0 - np.sqrt(omega0 ** 2 + 2)) ** 2)
    #    / scales.mul(2).sqrt()
    # )
    log_power_scale = (
        -0.25 * LOG_PI
        + 0.25 * (omega0 - np.sqrt(omega0 ** 2 + 2)) ** 2
        - log_scales.add(LOG2).mul(0.5)
    )
    log_power_scaled = power.log() + log_power_scale
    log_total_power = log_power_scaled.logsumexp(
        (1, 2), keepdims=True
    )  # (channels, freqs)
    log_density = log_power_scaled - log_total_power
    return log_density.exp()
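
The `log_scales` argument is typically derived from a set of target frequencies through the standard Morlet scale-frequency relation. A small sketch of that conversion (an assumption about how a caller might build the input, not part of the function above):

import math
import torch

omega0 = 6.0
fs = 100.0                               # sampling rate in Hz
dtime = 1.0 / fs                         # would be passed to morlet_fft_convolution
freqs = torch.linspace(1.0, 50.0, 25)    # target frequencies in Hz

# Morlet scale-frequency relation: scale = (omega0 + sqrt(2 + omega0**2)) / (4 * pi * f)
scales = (omega0 + math.sqrt(2.0 + omega0 ** 2)) / (4 * math.pi * freqs)
log_scales = scales.log().reshape(1, 1, -1, 1)   # shape (1, 1, freqs, 1), as documented above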
Example #16
    Tp = 2.5e-06
    Br = abs(Kr) * Tp

    alpha = 1.24588  # 1.1-1.4
    Fsr = alpha * Br
    # Fc = 5.3e9
    Fc = 0.

    Tsr = 1.2 * Tp
    Nsr = int(Fsr * Tsr)
    t = th.linspace(-Tsr / 2., Tsr / 2, Nsr)
    f = th.linspace(-Fsr / 2., Fsr / 2, Nsr)

    St = sar_tran(t, Tp, Kr, Fc)

    Yt = fftshift(fft(fftshift(St, dim=0), dim=0), dim=0)

    plt.figure(1)
    plt.subplot(221)
    plt.plot(t * 1e6, th.real(St))
    plt.plot(t * 1e6, th.abs(St))
    plt.grid()
    plt.legend(['Real part', 'Amplitude'])
    plt.title('Matched filter')
    plt.xlabel('Time/μs')
    plt.ylabel('Amplitude')
    plt.subplot(222)
    plt.plot(t * 1e6, th.angle(St))
    plt.grid()
    plt.subplot(223)
    plt.plot(f, th.abs(Yt))
Example #17
    plt.legend(['real', 'imag'])
    plt.subplot(222)
    plt.plot(t * 1e6, th.angle(St))
    plt.xlabel('Time/us')
    plt.subplot(223)
    plt.plot(t * 1e6, th.real(Sr))
    plt.plot(t * 1e6, th.imag(Sr))
    plt.xlabel('Time/us')
    plt.legend(['real', 'imag'])
    plt.subplot(224)
    plt.plot(t * 1e6, th.angle(Sr))
    plt.xlabel('Time/us')
    plt.show()

    # ---Frequency domain
    Yt = fftshift(fft(fftshift(St, dim=0), dim=0), dim=0)
    Yr = fftshift(fft(fftshift(Sr, dim=0), dim=0), dim=0)

    # ---Plot signals
    plt.figure(figsize=(10, 8))
    plt.subplot(221)
    plt.plot(t * 1e6, th.real(St))
    plt.grid()
    plt.title('Real part')
    plt.xlabel('Time/μs')
    plt.ylabel('Amplitude')
    plt.subplot(222)
    plt.plot(t * 1e6, th.imag(St))
    plt.grid()
    plt.title('Imaginary part')
    plt.xlabel('Time/μs')
Example #18
 def rfft(input, signal_ndim, normalized=False):
     norm = "ortho" if normalized else "backward"
     return fft.fft(input, dim=-1, norm=norm)
def cwt(
    data: torch.Tensor,
    scales: Union[np.ndarray, torch.Tensor],  # type: ignore
    wavelet: Union[ContinuousWavelet, str],
    sampling_period: float = 1.0,
) -> Tuple[torch.Tensor, np.ndarray]:  # type: ignore
    """Compute the single dimensional continuous wavelet transform.

    This function is a PyTorch port of pywt.cwt as found at:
    https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py

    Args:
        data (torch.Tensor): The input tensor of shape [batch_size, time].
        scales (torch.Tensor or np.array):
            The wavelet scales to use. One can use
            ``f = pywt.scale2frequency(wavelet, scale)/sampling_period`` to determine
            the physical frequency ``f``. Here, ``f`` is in hertz when the
            ``sampling_period`` is given in seconds.
        wavelet (ContinuousWavelet or str): The continuous wavelet to work with.
        sampling_period (float): Sampling period for the frequencies output (optional).
            The values computed for ``coefs`` are independent of the choice of
            ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
            period).

    Raises:
        ValueError: If a scale is too small for the input signal.

    Returns:
        Tuple[torch.Tensor, np.ndarray]: A tuple with the transformation matrix
            and frequencies in this order.
    """
    # accept array_like input; make a copy to ensure a contiguous array
    if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)
    if type(scales) is torch.Tensor:
        scales = scales.numpy()
    elif np.isscalar(scales):
        scales = np.array([scales])
    # if not np.isscalar(axis):
    #    raise np.AxisError("axis must be a scalar.")

    precision = 10
    int_psi, x = integrate_wavelet(wavelet, precision=precision)
    if type(wavelet) is ContinuousWavelet:
        int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
    int_psi = torch.tensor(int_psi, device=data.device)

    # convert int_psi, x to the same precision as the data
    x = np.asarray(x, dtype=data.cpu().numpy().real.dtype)

    size_scale0 = -1
    fft_data = None

    out = []
    for scale in scales:
        step = x[1] - x[0]
        j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
        j = j.astype(int)  # floor
        if j[-1] >= len(int_psi):
            j = np.extract(j < len(int_psi), j)
        int_psi_scale = int_psi[j].flip(0)

        # The padding is selected for:
        # - optimal FFT complexity
        # - to be larger than the two signals length to avoid circular
        #   convolution
        size_scale = _next_fast_len(data.shape[-1] + len(int_psi_scale) - 1)
        if size_scale != size_scale0:
            # Must recompute fft_data when the padding size changes.
            fft_data = fft(data, size_scale, dim=-1)
        size_scale0 = size_scale
        fft_wav = fft(int_psi_scale, size_scale, dim=-1)
        conv = ifft(fft_wav * fft_data, dim=-1)
        conv = conv[..., :data.shape[-1] + len(int_psi_scale) - 1]

        coef = -np.sqrt(scale) * torch.diff(conv, dim=-1)

        # transform axis is always -1
        d = (coef.shape[-1] - data.shape[-1]) / 2.0
        if d > 0:
            coef = coef[..., int(np.floor(d)):-int(np.ceil(d))]
        elif d < 0:
            raise ValueError("Selected scale of {} too small.".format(scale))

        out.append(coef)
    out_tensor = torch.stack(out)
    if type(wavelet) is Wavelet:
        out_tensor = out_tensor.real
    else:
        out_tensor = out_tensor if wavelet.complex_cwt else out_tensor.real

    frequencies = scale2frequency(wavelet, scales, precision)
    if np.isscalar(frequencies):
        frequencies = np.array([frequencies])
    frequencies /= sampling_period
    return out_tensor, frequencies
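
Since the function above is documented as a port of pywt.cwt, a quick reference for the expected call shape and output, using pywt itself (an equivalence assumption based on the docstring, not a test of this port):

import numpy as np
import pywt

sig = np.sin(2 * np.pi * 7 * np.linspace(0, 1, 200))
scales = np.arange(1, 31)
coefs, freqs = pywt.cwt(sig, scales, "morl", sampling_period=1 / 200)
print(coefs.shape)   # (30, 200): one row of coefficients per scale
print(freqs[:3])     # frequencies in Hz corresponding to the first scales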
Example #20
def test(root_dir='assets\\nsynth_test',
         model_class=Model1,
         dataset_class=NSynthDataSet,
         memoryLimitInMB=1024,
         device=torch.device('cpu'),
         load_path='model.pth',
         plotOnly=True,
         windowStep=4000):

    # Calculate chunk size based on the memory limit allowed
    windowsPerWav = int(NSynthDataSet.WavFileTime *
                        NSynthDataSet.SamplingFrequency // windowStep)

    # chunkSize is also needed below for the sampler, so compute it regardless
    # of which dataset class is used
    chunkSize = ceil(
        int(memoryLimitInMB * 1024 * 1024 /
            NSynthDataSet.FrequencyBinsCount / 4) /
        windowsPerWav) * windowsPerWav

    if dataset_class == NSynthChunkedDataSet:
        data_set = NSynthChunkedDataSet(root_dir=root_dir,
                                        chunkSize=chunkSize,
                                        model=model_class,
                                        windowStep=windowStep,
                                        device=device)
    else:
        data_set = dataset_class(root_dir=root_dir,
                                 model=model_class,
                                 windowStep=windowStep)

    # Initialize the sampler to split between train and validation sets
    allChunkIndices = list(range(0, ceil(len(data_set) / chunkSize)))
    lastChunkSize = len(data_set) % chunkSize
    if lastChunkSize == 0:
        lastChunkSize = chunkSize

    testWavRandomSampler = WavRandomSampler(chunkSize, allChunkIndices,
                                            allChunkIndices[-1], lastChunkSize)

    testDataLoader = DataLoader(
        data_set,
        shuffle=False,
        num_workers=0,
        batch_size=None,  # Needed explicitly - otherwise auto_collation makes batch sampling useless!
        sampler=BatchSampler(testWavRandomSampler,
                             batch_size=1,
                             drop_last=False))

    # Plotter object to plot the losses for each window
    plotter = VisdomLinePlotter(env_name='Accuracy plot')

    # Port the model to device
    model = model_class()
    model.load_state_dict(torch.load(load_path))
    model = model.to(device)

    model.eval()
    frequencyBinRange = torch.from_numpy(
        (NSynthDataSet.SamplingFrequency / 2.0 / data_set.outputBinsCount) *
        np.array(range(data_set.outputBinsCount))).to(device)
    for batch_idx, batch in enumerate(testDataLoader):
        # Get the inputs from the dataset
        inputs, labels, labelIdx = batch
        fftInputs = fft(inputs, dim=2,
                        norm='forward').abs()[:, :, 0:ModelBase.InputShape[0]]
        fftInputs = fftInputs / fftInputs.max()

        print("Wav file =", data_set.labelVector[labelIdx])

        # forward
        if plotOnly:
            x1, x2, x3, x4, x5 = model.forward_with_plt(inputs)
            #print ('x1[', x1.shape, '] =', x1, '\n',
            #       'x2[', x2.shape, '] =', x2, '\n',
            #       'x3[', x3.shape, '] =', x3, '\n',
            #       'x4[', x4.shape, '] =', x4, '\n',
            #       'x5[', x5.shape, '] =', x5, '\n')

            fftInputs = fftInputs.detach().cpu().numpy()
            #outputs = outputs.detach().cpu().numpy()
            outputs = None
            expected = labels.detach().cpu().numpy()

            #print(outputs,expected)
            fig, axs = plt.subplots(6)
            axs[0].plot(expected[0, :], color='black')
            axs[1].plot(x1[0, 0, :], color='red')
            axs[2].plot(x2[0, 0, :], color='green')
            axs[3].plot(x3[0, 0, :], color='blue')
            axs[4].plot(x4[0, :], color='magenta')
            axs[5].plot(x5[0, :], color='yellow')
            plt.show()
        else:
            # Run the forward pass here; `outputs` is otherwise never defined in this
            # branch (assumes the model consumes the raw windows, as in forward() above)
            outputs = model(inputs)
            frequencyWeightedSum = torch.sum((outputs * frequencyBinRange),
                                             axis=1)
            outputSum = torch.sum(outputs, axis=1)
            observed = (frequencyWeightedSum /
                        outputSum).detach().cpu().numpy()[0]
            plotter.plot('Hz', 'Observed', 'Frequencies', batch_idx, observed)

            expectedFrequencyWeightedSum = torch.sum(
                (labels * frequencyBinRange), axis=1)
            expectedSum = torch.sum(labels, axis=1)
            expected = (expectedFrequencyWeightedSum /
                        expectedSum).detach().cpu().numpy()[0]
            plotter.plot('Hz', 'Expected', 'Frequencies', batch_idx, expected)
Example #21
 def __init__(self, embed_dict, embed_dim):
     self.weights = fft(torch.eye(embed_dim))
     self.weights = torch.cat([self.weights.real, self.weights.imag],
                              dim=-1).type(torch.float16)
Example #22
def fft(x, n=None, axis=0, norm="backward", shift=False):
    """FFT in torchsar

    FFT in torchsar.

    Parameters
    ----------
    x : {torch array}
        Complex input is supported. Since torch 1.7 and above support complex arrays,
        when :attr:`x` is in real representation (last dimension is 2: real, imag),
        it is converted to complex representation and converted back after the FFT.
    n : int, optional
        number of fft points (the default is None --> equals to signal dimension)
    axis : int, optional
        axis of fft (the default is 0, which the first dimension)
    norm : {None or str}, optional
        Normalization mode. For the forward transform (fft()), these correspond to:
        - "forward" - normalize by ``1/n``
        - "backward" - no normalization (default)
        - "ortho" - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
    shift : bool, optional
        shift the zero frequency to center (the default is False)

    Returns
    -------
    y : {torch array}
        fft results torch array with the same type as :attr:`x`

    Raises
    ------
    ValueError
        nfft is smaller than the signal dimension.
    """

    if norm is None:
        norm = 'backward'

    if (x.size(-1) == 2) and (not th.is_complex(x)):
        realflag = True
        x = th.view_as_complex(x)
        if axis < 0:
            axis += 1
    else:
        realflag = False

    d = x.size(axis)
    if n is None:
        n = d
    if d < n:
        x = padfft(x, n, axis, shift)
    elif d > n:
        raise ValueError('nfft is smaller than the signal dimension!')

    if shift:
        y = thfft.fftshift(thfft.fft(thfft.fftshift(x, dim=axis),
                                     n=n,
                                     dim=axis,
                                     norm=norm),
                           dim=axis)
    else:
        y = thfft.fft(x, n=n, dim=axis, norm=norm)

    if realflag:
        y = th.view_as_real(y)

    return y
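
The real-representation handling in the wrapper boils down to a view_as_complex / view_as_real round trip around torch.fft.fft. A minimal standalone illustration of that convention (assuming torch >= 1.8; not taken from torchsar itself):

import torch
import torch.fft as thfft

# a (batch, N, 2) real-representation signal: the last dim holds (real, imag)
x = torch.randn(4, 256, 2)

xc = torch.view_as_complex(x)                 # (4, 256), complex64
yc = thfft.fft(xc, dim=-1, norm='backward')
y = torch.view_as_real(yc)                    # back to (4, 256, 2)
print(y.shape)                                # torch.Size([4, 256, 2])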
Example #23
 def rfft(x):
     return fft.fft(x,
                    n=n_fft,
                    dim=-2,
                    norm='ortho' if normalized else 'backward')
Example #24
    def split_step_solver(self, a, T, d, c, L):
        '''
        Parameters
        ----------
        a : TYPE: torch.int64 tensor of shape [batch_size, seq_len], optional
            DESCRIPTION: Pulse amplitude set.
        T : TYPE: float
            DESCRIPTION: Pulse width.
        d : TYPE: float
            DESCRIPTION: dispersion coefficient.
        c : TYPE: float
            DESCRIPTION: nonlinearity coefficient.
        L : TYPE: float
            DESCRIPTION: End of transmission line (z_end).

        Returns
        -------
        t : TYPE: torch.float32 tensor of shape [dim_t]
            DESCRIPTION: Time points. See prepare_data description.
        z : TYPE: torch.float32 tensor of shape [dim_z]
            DESCRIPTION: Points in space. See prepare_data description.
        u : TYPE: torch.complex128 tensor of shape [batch_size, dim_z, dim_t].
            DESCRIPTION: Output of the split-step solution.
        '''
        z = torch.linspace(0, L, self.dim_z, device=self.device)

        tMax = self.t_end + 4 * sqrt(2 * (1 + L**2))
        tMin = -tMax

        dt = (tMax - tMin) / self.dim_t
        t = torch.linspace(tMin, tMax - dt, self.dim_t, device=self.device)

        # prepare frequencies
        dw = 2 * pi / (tMax - tMin)
        w = dw * torch.cat(
            (torch.arange(0, self.dim_t / 2 + 1, device=self.device),
             torch.arange(-self.dim_t / 2 + 1, 0, device=self.device)))

        # prepare linear propagator
        LP = torch.exp(-1j * d * self.dz / 2 * w**2)

        # Set initial condition
        u = torch.zeros(a.shape[0],
                        self.dim_z,
                        self.dim_t,
                        dtype=self.complex_type,
                        device=self.device)

        buf = self.Etanal(t, torch.tensor(0, device=self.device), a, T)
        u[:, 0, :] = buf

        n = 0
        # Numerical integration (split-step)
        for i in range(1, int(L / self.dz) + 1):
            buf = ifft(LP * fft(buf))
            buf = buf * torch.exp(1j * c * self.dz * buf.abs()**2)
            buf = ifft(LP * fft(buf))

            if i % self.z_stride == 0:
                n += 1
                u[:, n, :] = buf

        # Dispersion compensation procedure (back propagation D**(-1))
        if self.dispersion_compensate:
            zw = torch.mm(z.view(z.shape[-1], 1), w.view(1, w.shape[-1])**2)
            u = ifft(torch.exp(1j * d * zw) * fft(u))

        return t, z, u
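
The integration loop applies a symmetric (Strang) split-step update: half a dispersion step in the frequency domain, a full nonlinear phase rotation in the time domain, then another half dispersion step. A self-contained sketch of a single such update mirroring the loop body above (variable names and toy values here are illustrative assumptions):

import math
import torch
from torch.fft import fft, ifft

def split_step(u, w, d, c, dz):
    half_linear = torch.exp(-1j * d * dz / 2 * w ** 2)   # dispersion, half step, applied in frequency domain
    u = ifft(half_linear * fft(u))                        # first linear half step
    u = u * torch.exp(1j * c * dz * u.abs() ** 2)         # full nonlinear step in the time domain
    u = ifft(half_linear * fft(u))                        # second linear half step
    return u

dim_t = 1024
u = torch.randn(dim_t, dtype=torch.complex128)            # toy initial field
w = 2 * math.pi * torch.fft.fftfreq(dim_t, d=0.01)        # angular frequencies matching fft ordering
u = split_step(u, w, d=0.5, c=1.0, dz=1e-3)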