def overlap_and_add(signal, frame_step):
    """Reconstruct a signal from a framed representation by overlap-add.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by
    `frame_step`. The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. Rank must be at least 2.
            Non-contiguous tensors are accepted.
        frame_step: An integer denoting overlap offsets. Must be less than or
            equal to frame_length.

    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added
        frames of signal's inner-most two dimensions. It has the same dtype
        and device as `signal`.

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Split every frame into "subframes" whose length is the greatest common
    # divisor of frame_length and frame_step. Both the frame length and the
    # hop are then whole numbers of subframes, so overlap-add becomes a single
    # index_add_ over subframe slots.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # reshape (not view) so that non-contiguous inputs are handled too.
    subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)

    # Destination slot of every subframe. The index is created directly as
    # int64 on the signal's device; routing it through the signal's dtype
    # (as signal.new_tensor would) corrupts indices for narrow dtypes such as
    # uint8 or float16.
    frame = torch.arange(output_subframes, device=signal.device).unfold(
        0, subframes_per_frame, subframe_step)
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes,
                              subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    return result.view(*outer_dimensions, -1)
def augment(signal, target, increase_size):
    """Append synthetic samples made by averaging random same-class pairs.

    The first `int(increase_size / 2)` synthetic samples are averages of two
    randomly chosen samples labelled 1 and receive label 1; the remaining
    ones are averages of two randomly chosen samples labelled 0 and receive
    label 0. A sample may be paired with itself.

    Args:
        signal (torch.Tensor): data indexed as
            (number of samples, rows, columns).
        target (torch.Tensor): labels, one per sample; entries equal to 1 and
            0 define the two classes. Converted with `.data.numpy()`.
        increase_size (int): number of synthetic samples to create.

    Returns:
        tuple: (`signal` with the synthetic samples concatenated along
        dimension 0, `target` with the synthetic labels concatenated along
        dimension 0).
    """
    synthetic = torch.zeros(increase_size, signal.size(1), signal.size(2))
    synthetic_labels = torch.zeros(increase_size)

    labels = target.data.numpy()
    positive_rows = np.where(labels == 1)[0]
    negative_rows = np.where(labels == 0)[0]

    # Leading half is class 1; the zero-initialised tail is already class 0.
    n_positive = int(increase_size / 2)
    synthetic_labels[:n_positive] = 1

    for row in range(increase_size):
        pool = positive_rows if row < n_positive else negative_rows
        first = signal[random.choice(pool)]
        second = signal[random.choice(pool)]
        synthetic[row] = (first + second) / 2

    augmented_signal = torch.cat((signal, synthetic), 0)
    augmented_target = torch.cat(
        (target, Variable(synthetic_labels.long())), 0)
    return augmented_signal, augmented_target
def ctp_s0(signal, config, device):
    '''
    Calculate the CTP bolus arrival time (bat) and the corresponding S0,
    the signal averaged over the frames before bat.

    bat is the first time index t (with a frame before and after it) where
    the mean signal rises by at least
    config.ctp_s0_threshold * (max - min of the mean curve) relative to
    frame t - 1 and keeps rising at frame t + 1. If no such index exists,
    bat is the number of frames and S0 is the average over all frames.

    Args:
        signal: tensor whose time axis is dimension 3,
            i.e. (n_slice, n_row, n_column, n_time).
        config: object providing `ctp_s0_threshold`.
        device: device on which the mean time curve is stored.

    Returns:
        (s0, bat): s0 has shape (n_slice, n_row, n_column); bat is an int.
    '''
    n_time = signal.size()[3]

    # Mean signal per time frame.
    sig_avg = torch.zeros([n_time], device=device, dtype=torch.float,
                          requires_grad=False)
    for t in range(n_time):
        sig_avg[t] = torch.mean(signal[..., t])

    threshold = config.ctp_s0_threshold * (torch.max(sig_avg) -
                                           torch.min(sig_avg))

    # The scan stops at n_time - 2 because the test needs sig_avg[t + 1];
    # a jump at the very last frame cannot be confirmed and is treated as
    # "not found" instead of indexing past the end of the curve.
    bat = n_time
    for t in range(1, n_time - 1):
        if (sig_avg[t] - sig_avg[t - 1] >= threshold
                and sig_avg[t + 1] > sig_avg[t]):
            bat = t
            break

    print(' Bolus arrival time (start from 0):', bat - 1)
    s0 = torch.mean(signal[..., :bat], dim=3)  # time dimension == 3
    return s0, bat
def mrp_s0(signal, config, device):
    '''
    Calculate the MRP bolus arrival time (bat) and the corresponding S0,
    the signal averaged over the frames before bat.

    The mean time curve is scanned for the first frame k whose relative
    deviation from the running mean of frames 0..k-1 is not below
    config.mrp_s0_threshold; bat is then k - 2. If no frame deviates, bat is
    n_time - 2. bat is clamped to at least 1 so that S0 always averages at
    least the first frame: a bat of 0 would average an empty slice (NaN) and
    a bat of -1 would average every frame but the last.

    Args:
        signal: tensor whose time axis is dimension 3,
            i.e. (n_slice, n_row, n_column, n_time).
        config: object providing `mrp_s0_threshold`.
        device: device on which the mean time curve is stored.

    Returns:
        (s0, bat): s0 has shape (n_slice, n_row, n_column); bat is an int.
    '''
    n_time = signal.size()[3]

    # Mean signal per time frame.
    sig_avg = torch.zeros([n_time], device=device, dtype=torch.float,
                          requires_grad=False)
    for t in range(n_time):
        sig_avg[t] = torch.mean(signal[..., t])

    bat = n_time - 2  # value used when no frame leaves the baseline
    for k in range(1, n_time):
        s0_avg = torch.mean(sig_avg[:k])
        within_baseline = (torch.abs(s0_avg - sig_avg[k]) / s0_avg <
                           config.mrp_s0_threshold)
        if not within_baseline:
            bat = k - 2
            break

    # Guard against an empty or wrapped-around baseline slice.
    bat = max(bat, 1)

    print(' Bolus arrival time (start from 0):', bat)
    s0 = torch.mean(signal[..., :bat], dim=3)  # time dimension == 3
    return s0, bat
def ct2ctc(signal, config, device):
    '''
    Turn a CT time series into a CTC tensor: every time frame has the
    baseline S0 (from ctp_s0) subtracted and is scaled by config.k_ct.

    Args:
        signal: tensor with the time axis last; assumed 4-D
            (n_slice, n_row, n_column, n_time) -- TODO confirm with callers.
        config: object providing `k_ct` and whatever ctp_s0 reads.
        device: device on which the result is allocated.

    Returns:
        float tensor with the same size as `signal`.

    Raises:
        ValueError: if the result contains any NaN.
    '''
    baseline, _ = ctp_s0(signal, config, device)

    ctc = torch.zeros(signal.size(), device=device, dtype=torch.float,
                      requires_grad=False)
    # Broadcast the baseline over the trailing time axis instead of looping
    # frame by frame.
    ctc[...] = config.k_ct * (signal - baseline.unsqueeze(-1))

    # Check computed CTC: should have no NaN value
    if torch.isnan(ctc).any():
        raise ValueError('Computed CTC contains NaN value, check out!')
    return ctc
def mr2ctc(signal, config, device):
    '''
    Turn an MR time series into a CTC tensor: every time frame is divided by
    the baseline S0 (from mrp_s0), log-transformed and scaled by
    -config.k_mr / config.TE.

    Args:
        signal: tensor with the time axis last; assumed 4-D
            (n_slice, n_row, n_column, n_time) -- TODO confirm with callers.
        config: object providing `k_mr`, `TE` and whatever mrp_s0 reads.
        device: device on which the result is allocated.

    Returns:
        float tensor with the same size as `signal`.

    Raises:
        ValueError: if the result contains any NaN.
    '''
    # TODO: use mask if needed
    baseline, _ = mrp_s0(signal, config, device)

    ctc = torch.zeros(signal.size(), device=device, dtype=torch.float,
                      requires_grad=False)
    # Broadcast the baseline over the trailing time axis instead of looping
    # frame by frame.
    ctc[...] = -config.k_mr / config.TE * torch.log(
        signal / baseline.unsqueeze(-1))

    # Check computed CTC: should have no NaN value
    if torch.isnan(ctc).any():
        raise ValueError('Computed CTC contains NaN value, check out!')
    return ctc