import torch
from torch.fft import fftn, ifftn


def batch_fft(data, normalize=False):
    """
    Compute the Fourier transform of a batch of 2D signals.
    Args:
        data: input tensor, (NxHxW)

    Returns:
        Batch Fourier transform of the input data.
    """

    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise ValueError(f'Data must be 2D, but got {dim}D.')

    dims = tuple(range(1, dim + 1))  # add one for batch dimension
    if normalize:
        norm = 'ortho'
    else:
        norm = 'backward'

    if not torch.is_complex(data):
        data = torch.complex(data, torch.zeros_like(data))
    freq = fftn(data, dim=dims, norm=norm)

    return freq
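# A minimal usage sketch for batch_fft; the (N, H, W) shape and the
# normalize flag below are illustrative, only the layout comes from the
# docstring:
x = torch.randn(8, 64, 64)                 # batch of 8 single-channel images
freq = batch_fft(x, normalize=True)        # ortho-normalized 2D FFT per sample
assert freq.shape == x.shape and torch.is_complex(freq)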
def kspace_downsample_patch(hr_patch: torch.Tensor,
                            center_crop: int = 25,
                            end_crop: int = 6) -> torch.Tensor:
    """
  Down-sample high resolution patch by k-space truncation.

  Note: end_crop worsens the picture quality much more than center_crop.

  Args:
    hr_patch: original high resolution patch
    center_crop: square dimension of center to be removed
    end_crop: final rows/cols to be removed
  Returns:
    lr_patch in shape (channel, patch_dim, patch_dim)
  """
    lr_patch = fftn(hr_patch)

    # remove last n cols and rows
    if end_crop > 0:
        lr_patch[:, -end_crop:, :] = 0
        lr_patch[:, :, -end_crop:] = 0
    # remove square center
    if center_crop > 0:
        i = max(center_crop // 2, 1)
        x_center = lr_patch.shape[1] // 2
        y_center = lr_patch.shape[2] // 2
        lr_patch[:, x_center - i:x_center + i, y_center - i:y_center + i] = 0

    return torch.abs(ifftn(lr_patch, dim=(-2, -1)))
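# A quick, hypothetical call of the k-space truncation above; the 64x64
# single-channel patch size is an assumption, only the (channel, H, W)
# layout comes from the docstring:
hr = torch.randn(1, 64, 64)
lr = kspace_downsample_patch(hr)           # defaults: center_crop=25, end_crop=6
assert lr.shape == hr.shape                # same size, high frequencies removed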
Example #3
 def apply(self, x: Tensor) -> Tensor:
     # reshape the flat input to 2D, take an ortho-normalized FFT, then
     # gather the K coefficients selected by self.index
     x_ = x.unflatten(dim=-1, sizes=self.size)
     y_ = fftn(x_, dim=(-2, -1), norm='ortho')
     y_ = y_.flatten(start_dim=-2)
     out_size = y_.size()[:-1] + (self.K,)
     y = torch.gather(y_, dim=-1, index=self.index.expand(out_size))
     return y
Example #4
def fft(input, inverse=False):
    """
        Interface with torch FFT routines for 2D signals.

        Example
        -------
        x = torch.randn(128, 32, 32, 2)
        x_fft = fft(x, inverse=True)

        Parameters
        ----------
        input : tensor
            complex input for the FFT
        inverse : bool
            True for computing the inverse FFT.
            NB : if direction is equal to 'C2R', then the transform
            is automatically inverse.
    """

    if not iscomplex(input):
        raise TypeError('The input should be complex (e.g. last dimension is 2)')

    if not input.is_contiguous():
        raise RuntimeError('Tensors must be contiguous!')

    if inverse:
        output = ifftn(input[..., 0] + 1j*input[..., 1], s=(-1, -1))
        #output = torch.ifft(input, 2, normalized=False)
        #output = torch.fft.ifft(input, 2, norm= "forward")
    else:
        output = fftn(input[..., 0] + 1j*input[..., 1], s=(-1, -1))
        #output = torch.fft(input, 2, normalized=False)
        #output = torch.fft.fft(input, 2, norm= "forward")
    output = torch.stack((output.real, output.imag), dim=-1)
    return output
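# Round-trip sketch for the interface above (assumes the project-local
# iscomplex helper accepts the (..., 2) real/imag layout): with the default
# normalization, fftn followed by ifftn is the identity up to float error.
x = torch.randn(128, 32, 32, 2)
x_f = fft(x)                     # forward 2D FFT
x_r = fft(x_f, inverse=True)     # inverse 2D FFT
print(torch.allclose(x, x_r, atol=1e-5))   # True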
Example #5
    def generate_single_frame(self, output_num, input_image=None):
        if not os.path.exists(self.output_folder):
            os.makedirs(self.output_folder)
        padding_size = 20
        _, _, fx, fy = self.GridGenerate(self.image_size + 2 * padding_size,
                                         grid_mode='real')
        f_grid = torch.sqrt(fx**2 + fy**2)  # spatial frequency fr = sqrt(fx^2 + fy^2)

        OTF_padding = self.OTF_form(f_grid)
        random_distribution = torch.rand([self.image_size, self.image_size])
        if input_image is None:
            input_image = torch.ones_like(random_distribution)
        random_distribution = random_distribution * input_image
        threshold = self.fluorophore_density
        fluorophore_num = round(self.fluorophore_density * self.image_size**2)
        fluorophore_loc = random_distribution < threshold

        while fluorophore_loc.sum() < fluorophore_num:
            threshold *= 2
            fluorophore_loc = random_distribution < threshold
        fluorophore_GT = torch.zeros_like(random_distribution)
        fluorophore_GT[fluorophore_loc] = 1
        ZeroPad_operation = nn.ZeroPad2d(20)
        fluorophore_GT_padding = ZeroPad_operation(
            fluorophore_GT)  # filtering with the OTF directly in the frequency domain causes edge-information crosstalk; zero padding removes it
        fluorophore_padding_diffractive_spectrum = torch_2d_fftshift(
            fft.fftn(fluorophore_GT_padding, dim=[0, 1])) * OTF_padding
        fluorophore_padding_diffractive_image = abs(
            fft.ifftn(
                torch_2d_ifftshift(fluorophore_padding_diffractive_spectrum),
                dim=[1, 0]))
        fluorophore_diffractive_image = fluorophore_padding_diffractive_image[
            padding_size:-padding_size, padding_size:-padding_size]
        # print(fluorophore_loc.sum())
        # common_utils.plot_single_tensor_image(fluorophore_diffractive_image)
        # image_size_real = self.image_size / self.downsample_rate
        AvgPool_operation = nn.AvgPool2d(kernel_size=self.downsample_rate,
                                         stride=self.downsample_rate)
        fluorophore_image_in_camera = AvgPool_operation(
            fluorophore_diffractive_image.unsqueeze(0).unsqueeze(0)).squeeze()
        # np.save(os.path.join(self.output_folder, output_num + '_label'), fluorophore_GT.numpy())

        fluorophore_GT_loc = (fluorophore_GT.numpy() == 1)
        fluorophore_GT_loc_xy = np.where(fluorophore_GT_loc)

        x = fluorophore_GT_loc_xy[0].astype(np.int32)
        y = fluorophore_GT_loc_xy[1].astype(np.int32)

        label_file_dir = os.path.join(self.output_folder,
                                      output_num + '_label.txt')
        # vectorized write of "x y" coordinate pairs, one fluorophore per line
        np.savetxt(label_file_dir, np.stack([x, y], axis=1), fmt='%d')

        return fluorophore_image_in_camera
Example #6
 def forward(self, add, model, data):
     self.checkDomainRange(model, data)
     if not add:
         data.zero()
     data[:] += fft.fftn(model.getNdArray(),
                         s=self.nfft,
                         dim=self.axes,
                         norm='ortho')
     return
Example #7
    def forward(self, inputs):
        suscp = inputs[0]
        kernel = inputs[1]

        # convolution theorem: pointwise-multiply the 3D spectrum by the
        # provided k-space kernel, then take the real part of the inverse FFT
        ks = fft.fftn(suscp, dim=[-3, -2, -1])

        ks = ks * kernel
        fm = torch.real(fft.ifftn(ks, dim=[-3, -2, -1]))

        return fm
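# Self-contained check of the convolution-theorem pattern used above
# (assuming torch and torch.fft are imported as in the snippet): with a
# hypothetical delta kernel, whose spectrum is all ones, the forward pass
# reduces to the identity.
suscp = torch.randn(8, 8, 8)
delta = torch.zeros(8, 8, 8)
delta[0, 0, 0] = 1.0
kernel = fft.fftn(delta, dim=[-3, -2, -1])           # all-ones spectrum
ks = fft.fftn(suscp, dim=[-3, -2, -1]) * kernel
fm = torch.real(fft.ifftn(ks, dim=[-3, -2, -1]))
print(torch.allclose(fm, suscp, atol=1e-5))          # True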
Example #8
 def __init__(self, D, m, b, wG, device, lambda_TV, P=1, alpha=0.5, rho=10):
     self.D = D
     self.m = m
     self.b = b
     self.wG = wG
     self.device = device
     self.lambda_TV = lambda_TV
     self.P = P
     self.alpha = alpha
     self.rho = rho
     self.Dconv = lambda x: torch.real(fft.ifftn(self.D * fft.fftn(x, dim=[0, 1, 2])))
Example #9
    def __call__(self, vis, u, v):
        input_grid = self.grid_2d(vis, u, v)
        input_grid = fftshift(input_grid)
        out = fftn(input_grid)
        out = fftshift(out)

        alpha = self.config['alpha']
        xl = int(0.5 * self.nx * (alpha - 1))
        yl = int(0.5 * self.ny * (alpha - 1))

        out = out[xl:xl + self.nx, yl:yl + self.ny]

        return out / self.gc
 def closure():
     optimizer.zero_grad()
     outputs = resnet(inputs_cat)
     outputs_cplx = outputs.type(torch.complex64)
     # loss
     RDFs_outputs = torch.real(
         fft.ifftn((fft.fftn(outputs_cplx, dim=[2, 3, 4]) * D),
                   dim=[2, 3, 4]))
     diff = torch.abs(rdfs - RDFs_outputs)
     loss_fidelity = (1 - alpha) * 0.5 * torch.sum(
         (weights * diff)**2)
     loss_l2 = rho * 0.5 * torch.sum(
         (x - outputs[0, 0, ...] + mu)**2)
     loss = loss_fidelity + loss_l2
     # loss = loss_fidelity
     loss.backward()
     return loss
    def forward(self, x, up_feat_in):
        # separate feature for two frequency
        freq_x = fft.fftn(x, dim=(-2, -1))
        freq_shift = fft.fftshift(freq_x, dim=(-2, -1))

        # low_freq_shift = self.easy_low_pass_filter(freq_x)
        # high_freq_shift = self.easy_high_pass_filter(freq_x)
        low_freq_shift, high_freq_shift = self.guassian_low_high_pass_filter(
            freq_shift)

        low_freq_ishift = fft.ifftshift(low_freq_shift, dim=(-2, -1))
        high_freq_ishift = fft.ifftshift(high_freq_shift, dim=(-2, -1))

        _low_freq_x = torch.abs(fft.ifftn(low_freq_ishift, dim=(-2, -1)))
        _high_freq_x = torch.abs(fft.ifftn(high_freq_ishift, dim=(-2, -1)))

        low_freq_x = self.low_project(_low_freq_x)
        high_freq_x = self.high_project(_high_freq_x)

        feat = torch.cat([x, low_freq_x, high_freq_x], dim=1)
        context = self.out_project(feat)
        fuse_feature = context + x  # optional skip connection

        if self.up_flag and self.smf_flag:
            if up_feat_in is not None:
                fuse_feature = self.upsample_add(up_feat_in, fuse_feature)
            up_feature = self.up(fuse_feature)
            smooth_feature = self.smooth(fuse_feature)
            return up_feature, smooth_feature

        if self.up_flag and not self.smf_flag:
            if up_feat_in is not None:
                fuse_feature = self.upsample_add(up_feat_in, fuse_feature)
            up_feature = self.up(fuse_feature)
            return up_feature

        if not self.up_flag and self.smf_flag:
            if up_feat_in is not None:
                fuse_feature = self.upsample_add(up_feat_in, fuse_feature)
            smooth_feature = self.smooth(fuse_feature)
            return smooth_feature
    def forward(self, x):
        """Performs a forward pass over the data.
        Args:
            x (torch.Tensor): An input tensor for computing the forward pass.
        Returns:
            A tensor containing the DBN's outputs.
       
        """

        #self.p = 0
        frames = x.size(1) #frames            
        dy, dx = x.size(2), x.size(3)
        ds = torch.zeros((x.size(0), frames, self.n_hidden))

        # Checking whether GPU is available and if it should be used
        if self.device == 'cuda':
            # Applies the GPU usage to the data            
            x = x.cuda()
            ds = ds.cuda()

        for fr in range(frames):
            sps = x[:, fr, :, :].squeeze()

            # Creating the Fourier spectrum: shifted 2D FFT over the spatial
            # dims; the original [:, :, :, 0] indexing assumed the legacy
            # torch<=1.7 real/imag layout, so take .real instead
            spec_data = fftshift(fftn(sps, dim=(-2, -1)), dim=(-2, -1)).real
            spec_data = torch.abs(spec_data.squeeze())
            
            # Flattening the samples' batch
            spec_data = spec_data.reshape(spec_data.size(0), self.n_visible)
        
            # Normalizing the samples' batch
            spec_data = ((spec_data - torch.mean(spec_data, 0, True)) / (torch.std(spec_data, 0, True) + c.EPSILON)).detach()

            spec_data, _ = self.hidden_sampling(spec_data)
            ds[:, fr, :] = spec_data.reshape((spec_data.size(0), self.n_hidden))

        x.detach()
        sps.detach()

        return ds.detach()
    def forward(self, x):
        # self.writer = writer
        freq_x = fft.fftn(x, dim=(-2, -1))
        freq_shift = fft.fftshift(freq_x, dim=(-2, -1))
        
        # low_freq_shift = self.easy_low_pass_filter(freq_x)
        # high_freq_shift = self.easy_high_pass_filter(freq_x)
        low_freq_shift, high_freq_shift = self.guassian_low_high_pass_filter(freq_shift)

        # low_freq_ishift = fft.ifftshift(low_freq_shift, dim=(-2, -1))
        high_freq_ishift = fft.ifftshift(high_freq_shift, dim=(-2, -1))

        # _low_freq_x = torch.abs(fft.ifftn(low_freq_ishift, dim=(-2, -1)))
        _high_freq_x = torch.abs(fft.ifftn(high_freq_ishift, dim=(-2, -1)))

        feat_rgb = self.sp(_high_freq_x)
        feat_dct = self.cp(x)
        feat_fuse = torch.cat((feat_rgb, feat_dct), dim=1)
        logits = self.head(feat_fuse)
        out = F.interpolate(logits, scale_factor=self.block_size, mode='bilinear', \
            align_corners=True)
        return out
Example #14
    def batch_image_OTF_filter(self, batch_image):
        batch_image = batch_image.squeeze()
        ZeroPad_operation = nn.ZeroPad2d(20)
        padding_size = 20
        _, _, fx, fy = self.GridGenerate(batch_image.shape[-1] +
                                         2 * padding_size,
                                         grid_mode='real')

        OTF_padding = self.OTF_padding.to(batch_image.device)
        batch_image_padding = ZeroPad_operation(
            batch_image)  # filtering with the OTF directly in the frequency domain causes edge-information crosstalk; zero padding removes it
        batch_image_padding_diffractive_spectrum = torch_2d_fftshift(
            fft.fftn(batch_image_padding, dim=[1, 2])) * OTF_padding.unsqueeze(0)
        batch_image_padding_diffractive = abs(
            fft.ifftn(
                torch_2d_ifftshift(batch_image_padding_diffractive_spectrum),
                dim=[2, 1]))
        batch_image_diffractive = batch_image_padding_diffractive[
            :, padding_size:-padding_size, padding_size:-padding_size]
        return batch_image_diffractive
Example #15
    def generate_frame_batch(self):
        xx, yy, _, _ = self.GridGenerate(grid_mode='real')

        OTF = self.OTF
        random_distribution = torch.rand(
            [self.parallel_frames, self.image_size, self.image_size])
        fluorophore_loc = random_distribution < self.fluorophore_density
        fluorophore_GT = torch.zeros_like(random_distribution)
        fluorophore_GT[fluorophore_loc] = 1

        fluorophore_diffractive_spectrum = torch_2d_fftshift(
            fft.fftn(fluorophore_GT, dim=[1, 2])) * OTF.unsqueeze(0)
        fluorophore_diffractive_image = abs(
            fft.ifftn(torch_2d_ifftshift(fluorophore_diffractive_spectrum),
                      dim=[2, 1]))
        common_utils.plot_single_tensor_image(
            fluorophore_diffractive_image[0, :, :])
        # image_size_real = self.image_size / self.downsample_rate
        AvgPool_operation = nn.AvgPool2d(kernel_size=self.downsample_rate,
                                         stride=self.downsample_rate)
        fluorophore_image_in_camera = AvgPool_operation(
            fluorophore_diffractive_image.unsqueeze(0)).squeeze()

        return fluorophore_GT, fluorophore_image_in_camera
def _fft_convnd(input: Tensor, weight: Tensor, bias: Optional[Tensor],
                stride: Tuple[int], padding: Tuple[int], dilation: Tuple[int],
                groups: int) -> Tensor:

    output_size = _conv_shape(input.shape[2:], weight.shape[2:], stride,
                              padding, dilation)
    reversed_padding_repeated_twice = _reverse_repeat_tuple(padding, 2)
    padded_input = F.pad(input, reversed_padding_repeated_twice)

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_input.shape[2:], weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size)
        weight_s.append(s_size // d)

    X = rfftn(padded_input, s=s)

    W = rfft(weight, n=weight_s[-1])
    # handle dilation for last dim
    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(weight_s) > 1:
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + dilation[:-1] + (1, )
        W.imag.mul_(-1)
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)
    else:
        W.imag.mul_(-1)

    Y = _complex_matmul(X, W, groups)

    # handle stride
    if len(stride) > 1:
        for i, st in enumerate(stride[:-1]):
            if st > 1:
                Y = Y.reshape(*Y.shape[:i + 2], st, -1,
                              *Y.shape[i + 3:]).mean(i + 2)

            Y = ifft(Y, dim=i + 2)
            Y = Y.as_strided(
                Y.shape[:i + 2] + output_size[i:i + 1] + Y.shape[i + 3:],
                Y.stride())

    if stride[-1] > 1:
        n_fft = Y.size(-1) * 2 - 2
        new_n_fft = n_fft // stride[-1]
        step_size = new_n_fft // 2
        strided_Y_size = step_size + 1

        unfolded_Y_real = Y.real.unfold(-1, strided_Y_size, step_size)
        unfolded_Y_imag = Y.imag[..., 1:].unfold(-1, strided_Y_size - 2,
                                                 step_size)
        Y_pos_real, Y_pos_imag = unfolded_Y_real[..., ::2, :].sum(
            -2), unfolded_Y_imag[..., ::2, :].sum(-2)
        Y_neg_real, Y_neg_imag = unfolded_Y_real[..., 1::2, :].sum(-2).flip(
            -1), unfolded_Y_imag[..., 1::2, :].sum(-2).flip(-1)

        Y_real = Y_pos_real.add_(Y_neg_real)
        Y_imag = Y_pos_imag.add_(Y_neg_imag, alpha=-1)
        Y_imag = F.pad(Y_imag, [1, 1])

        Y = torch.view_as_complex(torch.stack((Y_real, Y_imag),
                                              -1)).div_(stride[-1])

    output = irfft(Y)

    # Remove extra padded values
    output = output[..., :output_size[-1]].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
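# Hedged sanity check for _fft_convnd, assuming its private helpers
# (_conv_shape, _reverse_repeat_tuple, _lcm, _complex_matmul) and the FFT
# imports are in scope: the FFT path should match
# torch.nn.functional.conv2d up to float error.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 16, 16)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
y_fft = _fft_convnd(x, w, b, stride=(1, 1), padding=(1, 1),
                    dilation=(1, 1), groups=1)
y_ref = F.conv2d(x, w, b, stride=1, padding=1, dilation=1)
print(torch.allclose(y_fft, y_ref, atol=1e-4))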
def _fft_conv_transposend(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    stride: Tuple[int],
    padding: Tuple[int],
    output_padding: Tuple[int],
    groups: int,
    dilation: Tuple[int],
) -> Tensor:

    output_size = _conv_transpose_shape(input.shape[2:], weight.shape[2:],
                                        stride, padding, output_padding,
                                        dilation)
    padded_output_size = tuple(o + 2 * p for o, p in zip(output_size, padding))

    s: List[int] = []
    weight_s: List[int] = []
    for i, (x_size, w_size, d, st) in enumerate(
            zip(padded_output_size, weight.shape[2:], dilation, stride)):
        s_size = max(x_size, w_size * d)

        # find s size that can be divided by stride and dilation
        rfft_even = 2 if i == len(stride) - 1 else 1
        factor = _lcm(st * rfft_even, d * rfft_even)

        offset = s_size % factor
        if offset:
            s_size += factor - offset
        s.append(s_size // st)
        weight_s.append(s_size // d)

    X = rfft(input, n=s[-1])
    W = rfft(weight, n=weight_s[-1])

    if stride[-1] > 1:
        X_neg_freq = X.flip(-1)[..., 1:]
        X_neg_freq.imag.mul_(-1)

        tmp = [X]
        for i in range(1, stride[-1]):
            if i % 2:
                tmp.append(X_neg_freq)
            else:
                tmp.append(X[..., 1:])

        X = torch.cat(tmp, -1)

    if dilation[-1] > 1:
        W_neg_freq = W.flip(-1)[..., 1:]
        W_neg_freq.imag.mul_(-1)

        tmp = [W]
        for i in range(1, dilation[-1]):
            if i % 2:
                tmp.append(W_neg_freq)
            else:
                tmp.append(W[..., 1:])

        W = torch.cat(tmp, -1)

    if len(s) > 1:
        X = fftn(X, s=s[:-1], dim=tuple(range(2, X.ndim - 1)))
        W = fftn(W, s=weight_s[:-1], dim=tuple(range(2, W.ndim - 1)))
        repeats = (1, 1) + stride[:-1] + (1, )
        if sum(repeats) > X.ndim:
            X = X.repeat(*repeats)

        repeats = (1, 1) + dilation[:-1] + (1, )
        if sum(repeats) > W.ndim:
            W = W.repeat(*repeats)

    Y = _complex_matmul(X, W, groups, True)

    output = irfftn(Y, dim=tuple(range(2, Y.ndim)))

    # Remove extra padded values
    index = (slice(None), ) * 2 + tuple(
        slice(p, o + p) for p, o in zip(padding, output_size))
    output = output[index].contiguous()

    # Optionally, add a bias term before returning.
    if bias is not None:
        output += bias[(slice(None), ) + (None, ) * (output.ndim - 2)]

    return output
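# Similarly hedged check for the transposed variant against
# torch.nn.functional.conv_transpose2d (helpers assumed in scope):
x = torch.randn(2, 4, 8, 8)
w = torch.randn(4, 3, 3, 3)    # (in_channels, out_channels // groups, kH, kW)
y_fft = _fft_conv_transposend(x, w, None, stride=(2, 2), padding=(1, 1),
                              output_padding=(1, 1), groups=1,
                              dilation=(1, 1))
y_ref = F.conv_transpose2d(x, w, None, stride=2, padding=1, output_padding=1)
print(torch.allclose(y_fft, y_ref, atol=1e-4))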
    def fit(self, dataset, batch_size=128, epochs=10, frames=6):
        """Fits a new MultFRRBM model.
        Args:
            dataset (torch.utils.data.Dataset | Dataset): A Dataset object containing the training data.
            batch_size (int): Amount of samples per batch.
            epochs (int): Number of training epochs.
            frames (int): Number of frames per sample.
        Returns:
            MSE (mean squared error) and log pseudo-likelihood from the training step.
        """

        batches = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=workers,
                             collate_fn=collate_fn)

        for ep in range(epochs):
            logger.info(f'Epoch {ep+1}/{epochs}')

            # Resetting epoch's MSE and pseudo-likelihood to zero
            mse, pl, cst = 0, 0, 0

            inner_trans = tqdm.tqdm(total=len(batches),
                                    desc='Batch',
                                    position=1)
            start = time.time()

            for ii, batch in enumerate(batches):
                x, y = batch

                # Checking whether GPU is available and if it should be used
                if self.device == 'cuda':
                    x = x.cuda()

                mse2, pl2, cst2 = 0, 0, 0
                cost, cost2 = 0, 0

                # Initializing the gradient
                #self.models[1].optimizer.zero_grad()

                for fr in range(frames):
                    x_ = x[:, fr, :, :].squeeze()

                    spec_data = fftshift(fftn(x_, dim=(-2, -1)),
                                         dim=(-2, -1)).real
                    spec_data = torch.abs(spec_data.squeeze())
                    spec_data = spec_data.reshape(spec_data.size(0),
                                                  self.n_visible)
                    spec_data = (
                        (spec_data - torch.mean(spec_data, 0, True)) /
                        (torch.std(spec_data, 0, True) + c.EPSILON)).detach()

                    x_ = x_.reshape(x.size(0), self.n_visible)
                    x_ = ((x_ - torch.mean(x_, 0, True)) /
                          (torch.std(x_, 0, True) + c.EPSILON)).detach()

                    # Performs the Gibbs sampling procedure
                    _, _, _, _, visible_states = self.models[0].gibbs_sampling(
                        spec_data)
                    _, _, _, _, visible_states2 = self.models[
                        1].gibbs_sampling(x_)

                    # Calculates the loss for further gradients' computation
                    cost = torch.mean(
                        self.models[0].energy(spec_data)) - torch.mean(
                            self.models[0].energy(visible_states))
                    cost2 = torch.mean(self.models[1].energy(x_)) - torch.mean(
                        self.models[1].energy(visible_states2))

                    # Initializing the gradient
                    self.models[0].optimizer.zero_grad()
                    self.models[1].optimizer.zero_grad()

                    # Computing the gradients
                    #cost /= frames
                    cost.backward()
                    #cost2 /= frames
                    cost2.backward()

                    # Updating the parameters
                    self.models[0].optimizer.step()
                    self.models[1].optimizer.step()

                    # Detaching the visible states from GPU for further computation
                    visible_states = visible_states.detach()
                    visible_states2 = visible_states2.detach()

                    # Calculating current's batch MSE
                    batch_mse1 = torch.div(
                        torch.sum(torch.pow(spec_data - visible_states, 2)),
                        batch_size).detach()
                    batch_mse2 = torch.div(
                        torch.sum(torch.pow(x_ - visible_states2, 2)),
                        batch_size).detach()

                    # Calculating the current's batch logarithm pseudo-likelihood
                    batch_pl1 = self.models[0].pseudo_likelihood(
                        spec_data).detach()
                    batch_pl2 = self.models[1].pseudo_likelihood(x_).detach()

                    # Summing up to epochs' MSE and pseudo-likelihood
                    mse2 += (batch_mse1 + batch_mse2)
                    pl2 += (batch_pl1 + batch_pl2)
                    cst2 += (cost.detach() + cost2.detach())

                mse2 /= frames
                pl2 /= frames
                cst2 /= frames

                #cost2 /= frames
                #cost2.backward()
                #self.models[1].optimizer.step()

                mse += mse2
                pl += pl2
                cst += cst2

                if ii % 100 == 99:
                    print('MSE:', (mse / (ii + 1)).item(), 'Cost:',
                          (cst / (ii + 1)).item())

                    w8 = self.models[0].W.cpu().detach().numpy()
                    img = _rasterize(w8.T,
                                     img_shape=(72, 96),
                                     tile_shape=(30, 30),
                                     tile_spacing=(1, 1))
                    im = Image.fromarray(img)
                    im.save('w8_spec.png')

                    w8 = self.models[1].W.cpu().detach().numpy()
                    img = _rasterize(w8.T,
                                     img_shape=(72, 96),
                                     tile_shape=(30, 30),
                                     tile_spacing=(1, 1))
                    im = Image.fromarray(img)
                    im.save('w8_gauss.png')

                    x = visible_states[:100].cpu().detach().reshape(
                        (100, 6912)).numpy()
                    x = _rasterize(x,
                                   img_shape=(72, 96),
                                   tile_shape=(10, 10),
                                   tile_spacing=(1, 1))
                    im = Image.fromarray(x)
                    im = im.convert("LA")
                    im.save('spectral.png')

                    x = visible_states2[:100].cpu().detach().reshape(
                        (100, 6912)).numpy()
                    x = _rasterize(x,
                                   img_shape=(72, 96),
                                   tile_shape=(10, 10),
                                   tile_spacing=(1, 1))
                    im = Image.fromarray(x)
                    im = im.convert("LA")
                    im.save('sample.png')

                inner_trans.update(1)

            mse /= len(batches)
            pl /= len(batches)
            cst /= len(batches)

            logger.info(
                f'MSE: {mse.item()} | log-PL: {pl.item()} | Cost: {cst.item()}'
            )

            end = time.time()
            self.dump(mse=mse.item(),
                      pl=pl.item(),
                      fe=cst.item(),
                      time=end - start)

        return mse, pl, cst
    def fit(self, dataset, batch_size=128, epochs=10, frames=6):
        """Fits a new RBM model.

        Args:
            dataset (torch.utils.data.Dataset): A Dataset object containing the training data.
            batch_size (int): Amount of samples per batch.
            epochs (int): Number of training epochs.
            frames (int): Number of frames per sample.

        Returns:
            MSE (mean squared error) and log pseudo-likelihood from the training step.

        """
        
        # Transforming the dataset into training batches
        batches = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn)

        # For every epoch
        for e in range(epochs):
            logger.info(f'Epoch {e+1}/{epochs}')

            # Calculating the time of the epoch's starting
            start = time.time()

            # Resetting epoch's MSE and pseudo-likelihood to zero
            mse, pl, cst = 0, 0, 0

            # For every batch
            inner = tqdm.tqdm(total=len(batches), desc='Batch', position=1)
            for ii, batch in enumerate(batches):
                samples, _ = batch

                # Checking whether GPU is available and if it should be used
                if self.device == 'cuda':
                    samples = samples.cuda()

                mse2, pl2, cst2 = 0, 0, 0
                cost = 0

                # Initializing the gradient
                self.optimizer.zero_grad()

                for fr in range(frames):
                    #torch.autograd.set_detect_anomaly(True)                    
                    sps = samples[:, fr, :, :].squeeze()

                    # Creating the Fourier spectrum (real part of the shifted
                    # 2D FFT; legacy [..., 0] indexing replaced by .real)
                    spec_data = fftshift(fftn(sps, dim=(-2, -1)), dim=(-2, -1)).real
                    spec_data = torch.abs(spec_data.squeeze())
                    
                    # Flattening the samples' batch
                    spec_data = spec_data.view(spec_data.size(0), self.n_visible)
                
                    # Normalizing the samples' batch
                    spec_data = ((spec_data - torch.mean(spec_data, 0, True)) / (torch.std(spec_data, 0, True) + c.EPSILON)).detach()

                    # Performs the Gibbs sampling procedure
                    _, _, _, _, visible_states = self.gibbs_sampling(spec_data)

                    # Calculates the loss for further gradients' computation
                    cost += torch.mean(self.energy(spec_data)) - \
                            torch.mean(self.energy(visible_states))

                    # Detaching the visible states from GPU for further computation
                    visible_states = visible_states.detach()

                    # Gathering the size of the batch
                    batch_size2 = sps.size(0)

                    # Calculating current's batch MSE
                    batch_mse = torch.div(
                        torch.sum(torch.pow(spec_data - visible_states, 2)), batch_size2).detach()

                    # Calculating the current's batch logarithm pseudo-likelihood
                    batch_pl = self.pseudo_likelihood(spec_data).detach()

                    # Summing up to epochs' MSE and pseudo-likelihood
                    mse2 += batch_mse
                    pl2 += batch_pl
                    cst2 += cost.detach()

                # Computing the gradients
                cost /= frames
                cost.backward()

                # Updating the parameters
                self.optimizer.step()

                mse2 /= frames
                pl2 /= frames
                cst2 /= frames

                mse += mse2
                pl  += pl2
                cst += cst2

                if ii % 100 == 99:
                    print('MSE:', (mse/(ii + 1)).item(), 'Cost:', (cst/(ii + 1)).item())
                    w8 = self.W.cpu().detach().numpy()
                    img = _rasterize(w8.T, img_shape=(72, 96), tile_shape=(30, 30), tile_spacing=(1, 1))
                    im = Image.fromarray(img)
                    im.save('w8_spec.png')

                    x = visible_states[:100].cpu().detach().reshape((100, 6912)).numpy()
                    x = _rasterize(x, img_shape=(72, 96), tile_shape=(10, 10), tile_spacing=(1, 1))
                    im = Image.fromarray(x)
                    im = im.convert("LA")
                    im.save('spectral.png')

                inner.update(1)

            # Normalizing the MSE and pseudo-likelihood with the number of batches
            mse /= len(batches)
            pl /= len(batches)
            cst /= len(batches)

            # Calculating the time of the epoch's ending
            end = time.time()

            # Dumps the desired variables to the model's history
            self.dump(mse=mse.item(), pl=pl.item(), fe=cst.item(), time=end-start)

            logger.info(f'MSE: {mse} | log-PL: {pl} | Cost: {cst}')
        self.p = 0
        return mse, pl, cst
    def reconstruct(self, dataset, bs=2**7):
        """Reconstructs batches of new samples.

        Args:
            dataset (torch.utils.data.Dataset): A Dataset object containing the testing data.

        Returns:
            Reconstruction error and visible probabilities, i.e., P(v|h).

        """

        logger.info('Reconstructing new samples ...')

        # Resetting MSE to zero
        mse = 0

        # Defining the batch size as the amount of samples in the dataset
        batch_size = bs

        # Transforming the dataset into training batches
        batches = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn)
        
        # For every batch
        inner = tqdm.tqdm(total=len(batches), desc='Batch', position=1)
        for _, batch in enumerate(batches):
            x, _ = batch
            frames = x.size(1) #frames            
            dy, dx = x.size(2), x.size(3)
            reconstructed = torch.zeros((bs, frames, self.n_visible))

            # Checking whether GPU is available and if it should be used
            if self.device == 'cuda':
                # Applies the GPU usage to the data            
                x = x.cuda()
                reconstructed = reconstructed.cuda()                

            for fr in range(frames):
                sps = x[:, fr, :, :].squeeze()

                # Creating the Fourier spectrum (real part of the shifted 2D FFT)
                spec_data = fftshift(fftn(sps, dim=(-2, -1)), dim=(-2, -1)).real
                spec_data = torch.abs(spec_data.squeeze())
                
                # Flattening the samples' batch
                spec_data = spec_data.view(spec_data.size(0),self.n_visible)

                # Normalizing the samples' batch
                spec_data = ((spec_data - torch.mean(spec_data, 0, True)) / (torch.std(spec_data, 0, True) + c.EPSILON)).detach()
            
                # Performs the Gibbs sampling procedure
                _, _, _, _, visible_states = self.gibbs_sampling(spec_data)

                visible_states = visible_states.detach()
                
                # Passing reconstructed data to a tensor
                reconstructed[:, fr, :] = visible_states

            # Calculating the current batch's reconstruction MSE over all
            # frames (compare against the per-frame reconstructions, not just
            # the last frame's visible states)
            batch_mse = torch.div(
                torch.sum(torch.pow(x.reshape((len(x), frames, dy*dx)) - reconstructed, 2)), bs).detach()

            # Summing up the reconstruction's MSE
            mse += batch_mse

            inner.update(1)
            break

        # Normalizing the MSE with the number of batches
        mse /= len(batches)
        logger.info(f'MSE: {mse}')

        return mse, x, reconstructed
Example #21
def ufft(data: Tensor,
         mask: Tensor,
         signal_ndim: int,
         normalized: bool = False) -> Tensor:
    """Undersampled fast Fourier transform over the last `signal_ndim` dims."""
    # honor signal_ndim and normalized (previously ignored: dim and norm were fixed)
    dims = tuple(range(-signal_ndim, 0))
    norm = 'ortho' if normalized else 'backward'
    return mask * fftn(data, dim=dims, norm=norm)
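# Usage sketch for ufft with a random Cartesian-style undersampling mask
# (shapes are illustrative, not from the source):
data = torch.randn(4, 1, 32, 32, dtype=torch.complex64)
mask = torch.rand(32, 32) > 0.5            # keep roughly half of k-space
k_under = ufft(data, mask, signal_ndim=2, normalized=True)
assert k_under.shape == data.shape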
                loss_fidelity = (1 - alpha) * 0.5 * torch.sum(
                    (weights * diff)**2)
                loss_l2 = rho * 0.5 * torch.sum(
                    (x - outputs[0, 0, ...] + mu)**2)
                loss = loss_fidelity + loss_l2
                # loss = loss_fidelity
                loss.backward()
                return loss

            optimizer.step(closure)

            # forward again to compute fidelity loss
            outputs = resnet(inputs_cat)
            outputs_cplx = outputs.type(torch.complex64)
            # loss
            RDFs_outputs = torch.real(
                fft.ifftn((fft.fftn(outputs_cplx, dim=[2, 3, 4]) * D),
                          dim=[2, 3, 4]))
            diff = torch.abs(rdfs - RDFs_outputs)
            loss_fidelity = torch.sum((weights * diff)**2)
            fidelity_fine = 'epochs: [%d/%d], Ks: [%d/%d], time: %ds, Fidelity loss: %f' % (
                epoch, niter, k + 1, K, time.time() - t0, loss_fidelity.item())
            print(fidelity_fine)
            if k == K - 1:
                file.write(fidelity_fine)
                file.write('\n')

        # dual update
        with torch.no_grad():
            mu = mu + x - outputs[0, 0, ...]

Example #23
    def _compute_variances(self,
                           Phi: Projection,
                           alpha: Tensor,
                           kmasks: Tensor,
                           num_probes: int,
                           num_cg_iters: int = 32,
                           cg_tol: float = 1e-10) -> Tensor:

        num_contrasts, size_x, size_y = kmasks.size()
        masks = ~kmasks.unsqueeze(dim=0)

        z = self._samp_probes((num_probes, num_contrasts, size_x, size_y))
        b = fftn(z, dim=(-2, -1), norm='ortho')
        b = masks * b

        b_lst = []
        norm = torch.tensor(0., device=self.device)

        if 'x' in self.grad_dim:
            kx = torch.arange(size_x, device=self.device).view(1, 1, -1, 1)
            kfactor_x = (1 - torch.exp(-2 * np.pi * 1j * kx / size_x))
            b_x = kfactor_x * b
            b_lst.append(b_x)
            norm = norm + torch.abs(kfactor_x)**2

        if 'y' in self.grad_dim:
            ky = torch.arange(size_y, device=self.device).view(1, 1, 1, -1)
            kfactor_y = (1 - torch.exp(-2 * np.pi * 1j * ky / size_y))
            b_y = kfactor_y * b
            b_lst.append(b_y)
            norm = norm + torch.abs(kfactor_y)**2

        corr = torch.zeros((1, 1, size_x, size_y), device=self.device)
        if self.grad_dim == 'x':
            corr[0, 0, 0, :] = 1
        elif self.grad_dim == 'y':
            corr[0, 0, :, 0] = 1
        elif self.grad_dim == 'xy':
            corr[0, 0, 0, 0] = 1
        norm = norm + corr
        b = torch.stack(b_lst, dim=1) / norm.unsqueeze(dim=1)

        b = ifftn(b, dim=(-2, -1), norm='ortho').real
        b = b.flatten(start_dim=-2)

        alpha = alpha.unsqueeze(dim=1).unsqueeze(dim=0)
        A = lambda x: self.alpha0 * (Phi.T(Phi(x))) + alpha * x
        out, _ = conjugate_gradient(A, b, -1, num_cg_iters, cg_tol)

        out = out.unflatten(dim=-1, sizes=(size_x, size_y))
        out = fftn(out, dim=(-2, -1), norm='ortho')

        if 'x' in self.grad_dim:
            out[:, 0] = torch.conj(kfactor_x) * out[:, 0] / norm
        if 'y' in self.grad_dim:
            out[:, -1] = torch.conj(kfactor_y) * out[:, -1] / norm

        if self.grad_dim == 'xy':
            out = out[:, 0] + out[:, -1]
        else:
            out = out.squeeze(dim=1)

        out = masks * out
        out = ifftn(out, dim=(-2, -1), norm='ortho').real

        var = (z * out).mean(dim=0).clamp(min=0)
        return var
Example #24
 def apply(self, x: Tensor) -> Tensor:
     y = fftn(x, dim=(-2, -1), norm='ortho')
     y[..., ~self.mask] = 0
     return y
def convolve_fft(array, kernel, axes=None):
    # FFT-based circular convolution; the kernel is assumed centered, so it
    # is ifftshift-ed to put its origin at index 0 before transforming
    # (torch.fft.ifftshift takes `dim`, not numpy-style `axes`)
    arrayfft = fftn(array, dim=axes)
    kernelfft = fftn(ifftshift(kernel, dim=axes), dim=axes)
    fftmult = kernelfft * arrayfft
    return torch.real(ifftn(fftmult, dim=axes))
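# Identity check for convolve_fft: a centered delta kernel is ifftshift-ed
# to the origin inside the function, so convolution leaves the input
# unchanged (the 33x33 grid is illustrative).
array = torch.randn(33, 33)
kernel = torch.zeros(33, 33)
kernel[16, 16] = 1.0                       # centered delta on an odd-sized grid
out = convolve_fft(array, kernel, axes=(-2, -1))
print(torch.allclose(out, array, atol=1e-5))   # True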
Example #26
    match_dict['feature.2.weight'] = 'conv2_w'
    match_dict['feature.2.bias'] = 'conv2_b'

    for var_name in net.state_dict().keys():
        print(var_name)
        key_in_model = match_dict[var_name]
        param_in_model = var_name.rsplit('.', 1)[1]
        if 'weight' in var_name:
            pth_state_dict[var_name] = torch.Tensor(
                np.transpose(p[key_in_model], (3, 2, 0, 1)))
        elif 'bias' in var_name:
            pth_state_dict[var_name] = torch.Tensor(np.squeeze(
                p[key_in_model]))
        if var_name == 'feature.0.weight':
            weight = pth_state_dict[var_name].data.numpy()
            weight = weight[:, ::-1, :, :].copy()  # cv2 bgr input
            pth_state_dict[var_name] = torch.Tensor(weight)

    torch.save(pth_state_dict, 'param.pth')
    net.load_state_dict(torch.load('param.pth'))
    x_t = torch.Tensor(np.expand_dims(np.transpose(x, (2, 0, 1)), axis=0))
    x_pred = net(x_t).data.numpy()
    pred_error = np.sum(
        np.abs(
            np.transpose(x_pred, (0, 2, 3, 1)).reshape(-1) -
            x_out.reshape(-1)))

    x_fft = fft.fftn(x_t, dim=[-2, -1])

    print('model_transfer_error:{:.5f}'.format(pred_error))
    t = cn.mulconj(xf, zf)
    kxzf = torch.sum(t, dim=1, keepdim=True)

    # [batch, 1, 121, 61, 2]
    alphaf = label.to(device=z.device) / (kzzf + lambda0)

    # [batch, 1, 121, 121]
    return torch.irfft(cn.mul(kxzf, alphaf), signal_ndim=2)


##############################################
# Compare the legacy torch.rfft (deprecated in torch 1.7, removed in 1.8)
# against the new torch.fft.fftn on the same input; both APIs coexist only
# in torch 1.7.

x = torch.rand((42, 32, 121, 121))

a = torch.rfft(x, signal_ndim=2, onesided=False)
b = fft.fftn(x, dim=[-2, -1])

ca = torch.view_as_complex(a)
print(a.shape)
print(b.shape)
print(torch.allclose(ca, b))

u = ca - b
v = u.abs()
h = torch.histc(v)

import matplotlib.pyplot as plt
plt.hist(v.flatten().numpy(), bins=500, log=True)
plt.show()

exit()
Example #28
            output = _fft(input, 1, normalized=(norm == 'ortho'))
        if norm == 'forward':
            output /= float(n)

        # Make complex and move back dimension to its original position
        if _torch_has_complex:
            output = torch.view_as_complex(output)
            output = utils.movedim(output, -1, dim)
        else:
            output = utils.movedim(output, -2, dim if dim >= 0 else dim - 1)

        return output


if _torch_has_fft_module:
    # new torch: forward to torch.fft.fftn, swallowing the legacy `real` keyword
    fftn = lambda *a, real=None, **k: fft_mod.fftn(*a, **k)
else:

    def fftn(input, s=None, dim=None, norm='backward', real=None):
        """N-dimensional discrete Fourier transform.

        Parameters
        ----------
        input : tensor
            Input signal.
            If torch <= 1.5, the last dimension must be of length 2 and
            contain the real and imaginary parts of the signal, unless
            `real is True`.
        s : sequence[int], optional
            Signal size in the transformed dimensions.
            If given, each dimension dim[i] will either be zero-padded or
            trimmed to the length s[i] before computing the FFT.
Example #29
    def _compute_imgs(self, grad: Tensor, kspaces: Tensor) -> Tensor:
        _, size_x, size_y = self.kspaces.size()
        num_grads, num_contrasts, _ = grad.size()
        grad = grad.view((self.num_grads, num_contrasts, size_x, size_y))
        grad_old = grad

        if self.complex_imgs and self.tie_real_imag:
            num_contrasts //= 2
            grad = grad[:, :num_contrasts] + 1j * grad[:, num_contrasts:]
            kspaces = kspaces[:num_contrasts]
        elif self.complex_imgs and not self.tie_real_imag:
            num_grads = self.num_grads // 2
            grad = grad[:num_grads] + 1j * grad[num_grads:]

        img_fft = torch.tensor(0., device=self.device)
        norm = torch.tensor(0., device=self.device)

        if 'x' in self.grad_dim:
            kx = torch.arange(size_x, device=self.device).view(1, -1, 1)
            kfactor_x = (1 - torch.exp(-2 * np.pi * 1j * kx / size_x))

            grad_x = grad[0]
            grad_x_fft = fftn(grad_x, dim=(-2, -1), norm='ortho')

            img_fft = img_fft + torch.conj(kfactor_x) * grad_x_fft
            norm = norm + torch.abs(kfactor_x)**2

        if 'y' in self.grad_dim:
            ky = torch.arange(size_y, device=self.device).view(1, 1, -1)
            kfactor_y = (1 - torch.exp(-2 * np.pi * 1j * ky / size_y))

            grad_y = grad[-1]
            grad_y_fft = fftn(grad_y, dim=(-2, -1), norm='ortho')

            img_fft = img_fft + torch.conj(kfactor_y) * grad_y_fft
            norm = norm + torch.abs(kfactor_y)**2

        corr = torch.zeros((1, size_x, size_y), device=self.device)
        if self.grad_dim == 'x':
            corr[0, 0, :] = 1
        elif self.grad_dim == 'y':
            corr[0, :, 0] = 1
        elif self.grad_dim == 'xy':
            corr[0, 0, 0] = 1
        norm = norm + corr

        img_fft = img_fft / norm * (self.kspaces == 0) + self.kspaces
        img = ifftn(img_fft, dim=(-2, -1), norm='ortho')
        # img.real = img.real.clamp(min=0, max=1)
        # img.imag = img.imag.clamp(min=0, max=1)

        if self.normalize:
            img = img * self.scale + self.bias

        if not self.complex_imgs:
            img = img.real

        # if self.complex_imgs:
        #     num_contrasts //= 2
        #     img_real = img[:num_contrasts]
        #     img_imag = img[num_contrasts:]
        #     img = img_real + 1j * img_imag

        return img