# Collected PyTorch examples of F.softshrink (soft-thresholding). The snippets
# below come from different projects and assume at least the following imports
# (plus numpy as np, math, time, warnings, logging, etc., where used):
import torch
import torch.nn.functional as F


def forward(self, x):
    residual = x
    y1 = residual
    if self.downsample is not None:
        y1 = self.downsample(residual)
    # 1x1 conv -> BN -> soft-threshold
    y2 = F.conv2d(residual, self.weight1)
    y2 = self.bn1(y2)
    y2 = F.softshrink(y2, self.tau)
    # grouped conv with a shared kernel repeated across output channels
    w = self.weight2.repeat(self.out_chnls, 1, 1, 1)
    y2 = F.conv2d(y2, w, stride=self.stride, padding=self.padding,
                  groups=self.out_chnls)
    y2 = self.bn2(y2)
    y2 = self.activation(y2)
    y2 = self.conv3(y2)
    y2 = self.bn3(y2)
    # residual connection followed by a final soft-threshold
    y = y1 + y2
    y = F.softshrink(y, self.tau)
    return y
def encode(model, x):
    N = x.shape[0]
    x = x.t()
    tol = 1e-10
    D = model.D.detach()
    lambd_eta = model.eta * model.lambd
    Dtx = torch.matmul(D.t(), x).detach()
    z_ = F.softshrink(model.eta * Dtx, lambd_eta)
    # ISTA iterations
    for i in range(model.T):
        res = torch.matmul(D, z_) - x
        Dtres = torch.matmul(D.t(), res).detach()
        z = F.softshrink(z_ - model.eta * Dtres, lambd=lambd_eta)
        if torch.norm(z - z_) < tol:
            break
        else:
            z_ = z
    # computing closed form solution given sign pattern
    alfa = z
    SIGNS = torch.sign(z)
    for i in range(N):
        S = (z[:, i] != 0)
        if np.sum(S.numpy()) == 0:
            continue  # no active atoms for this sample; skip it
        Gs = torch.matmul(model.D[:, S].t(), model.D[:, S])
        b = (torch.matmul(model.D[:, S].t(), x[:, i])
             - model.lambd * SIGNS[S, i])
        alfa[S, i] = torch.matmul(torch.inverse(Gs), b)
    return alfa
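# On the support S found above, the lasso optimality condition
#     D_S^T (D_S a_S - x) + lambda * sign(a_S) = 0
# yields the closed form a_S = (D_S^T D_S)^{-1} (D_S^T x - lambda * sign(a_S)),
# which is exactly what the torch.inverse / matmul pair in the second loop
# computes per sample: a standard "debiasing" step given the sign pattern.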
def coord_descent(x, W, z0=None, alpha=1.0, maxiter=1000, tol=1e-6, verbose=False):
    input_dim, code_dim = W.shape  # [D,K]
    batch_size, input_dim1 = x.shape  # [N,D]
    assert input_dim1 == input_dim
    tol = tol * code_dim

    if z0 is None:
        z = x.new_zeros(batch_size, code_dim)  # [N,K]
    else:
        assert z0.shape == (batch_size, code_dim)
        z = z0

    # initialize b
    # TODO: how should we initialize b when 'z0' is provided?
    b = torch.mm(x, W)  # [N,K]

    # precompute S = I - W^T @ W
    S = -torch.mm(W.T, W)  # [K,K]
    S.diagonal().add_(1.)

    # loss function
    def fn(z):
        x_hat = torch.matmul(z, W.T)
        loss = 0.5 * (x_hat - x).pow(2).sum() + alpha * z.abs().sum()
        return loss

    # update function
    def cd_update(z, b):
        z_next = F.softshrink(b, alpha)  # [N,K]
        z_diff = z_next - z  # [N,K]
        k = z_diff.abs().argmax(1)  # [N]
        kk = k.unsqueeze(1)  # [N,1]
        b = b + S[:, k].T * z_diff.gather(1, kk)  # [N,K] += [N,K] * [N,1]
        z = z.scatter(1, kk, z_next.gather(1, kk))
        return z, b

    active = torch.arange(batch_size, device=W.device)
    for i in range(maxiter):
        if len(active) == 0:
            break
        z_old = z[active]
        z_new, b[active] = cd_update(z_old, b[active])
        update = (z_new - z_old).abs().sum(1)
        z[active] = z_new
        active = active[update > tol]
        if verbose:
            print('iter %i - loss: %0.4f' % (i, fn(F.softshrink(b, alpha))))

    z = F.softshrink(b, alpha)
    return z
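# A usage sketch for coord_descent (hypothetical shapes and values, not from
# the original source; assumes the imports at the top):
torch.manual_seed(0)
W = torch.randn(64, 128)                # [D,K] dictionary
W = W / W.norm(dim=0, keepdim=True)     # unit-norm atoms
x = torch.randn(16, 64)                 # [N,D] batch of signals
z = coord_descent(x, W, alpha=0.1)      # [N,K] sparse codes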
def fista_forward(self, data, return_loss=False):
    """ Implements FISTA. """
    if return_loss:
        loss = []
    # Initializations.
    yk = self.We(data)
    xprev = torch.zeros(yk.size()).to(self.device)
    t = 1
    # Iterations.
    for it in range(self.n_iter):
        # Update logistics.
        residual = F.linear(yk, self.We.weight.t()) - data
        # ISTA step
        tmp = yk - self.We(residual) / self.L
        xk = F.softshrink(tmp, lambd=self.thresh)
        # FISTA stepsize update:
        tnext = (1 + (1 + 4 * (t ** 2)) ** .5) / 2
        fact = (t - 1) / tnext
        # Use momentum to update code estimate.
        yk = xk + (xk - xprev) * fact
        # Update "prev" stuff.
        xprev = xk
        t = tnext
        if return_loss:
            loss += [self.lossFcn(yk, data)]
    # Fin.
    if return_loss:
        return yk, loss
    return yk
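# Standalone sketch (not from the original source): the schedule above is the
# standard FISTA momentum; the factor (t - 1) / t_next grows from 0 toward 1.
t = 1.0
for k in range(5):
    t_next = (1 + (1 + 4 * t ** 2) ** 0.5) / 2
    print(round((t - 1) / t_next, 3))   # 0.0, 0.282, 0.434, 0.531, 0.599
    t = t_next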
def forward(self, input, k, d):
    hks = k.shape[-1] // 2
    hds = d.shape[-1] // 2
    x_padding = (hks, hks, hks, hks)
    r_padding = (hds, hds, hds, hds)
    output = []
    for c in range(input.size(1)):
        y = input[:, c].unsqueeze(1)
        x = y.clone()
        for i in range(self.n_iter):
            # z update: analysis coefficients, soft-thresholded
            z = F.conv2d(F.pad(x, (2, 2, 2, 2), 'replicate'), self.weight)
            z = F.softshrink(z, self.lambd / max(1e-4, self.beta[i].item()))
            # x update
            for j in range(self.n_in):
                r0 = y - F.conv2d(F.pad(x, x_padding, 'replicate'), k)
                r1 = z - F.conv2d(F.pad(x, (2, 2, 2, 2), 'replicate'), self.weight)
                r = torch.cat([r0, r1], dim=1)
                r_pad = F.pad(r, r_padding, 'replicate')
                for l in range(3):
                    x = x + F.conv2d(r_pad[:, l].unsqueeze(0),
                                     d[i, l].unsqueeze(0).unsqueeze(0))
                x = x.clamp(0, 1)
        output.append(x.clone())
    output = torch.cat(output, 1)
    return output
def resample(image: torch.Tensor, displacement, image_is_displacement=False):
    shape = displacement[0].shape
    grid = torch.meshgrid(
        torch.linspace(-1, 2 * shape[0] / image.shape[0] - 1, shape[0]).float().cuda(),
        torch.linspace(-1, 2 * shape[1] / image.shape[1] - 1, shape[1]).float().cuda())
    df = torch.stack([
        torch.add(grid[1 - i], displacement[i], alpha=2 / image.shape[1 - i])
        for i in range(2)
    ], 2)[None, ]
    image = image[None, None, ]
    if not image_is_displacement:
        out = F.grid_sample(image, df, padding_mode='zeros')
    else:
        out = F.grid_sample(image, df, padding_mode='border')
        # extrapolate the displacement beyond the border using boundary slopes;
        # softshrink(df, 1) keeps only the out-of-range part of the sampling grid
        dx = torch.mean((image[:, :, :, -1] - image[:, :, :, 0]), 2) / 2
        dy = torch.mean((image[:, :, -1, :] - image[:, :, 0, :]), 2) / 2
        pad_offset = F.softshrink(df, 1) * torch.stack((dx, dy), 2)[:, :, None, None, :]
        pad_offset = torch.sum(pad_offset, 4)
        out += pad_offset
    return out[0][0]
def evaluate(args, epoch, model, data_loader, writer):
    model.eval()
    losses = []
    start = time.perf_counter()
    with torch.no_grad():
        if epoch != 0:
            for iter, data in enumerate(data_loader):
                input, target, mean, std, norm = data
                input = input.to(args.device)
                target = target.to(args.device)
                output = model(input)
                # output = transforms.complex_abs(output)  # complex to real
                # output = transforms.root_sum_of_squares(output, dim=1)
                output = output.squeeze()
                loss = F.l1_loss(output, target)
                losses.append(loss.item())
            x = model.module.get_trajectory()
            v, a = get_vel_acc(x)
            # penalize only constraint violations: softshrink zeroes out
            # accelerations/velocities within the allowed range
            acc_loss = torch.sqrt(torch.sum(torch.pow(F.softshrink(a, args.a_max), 2)))
            vel_loss = torch.sqrt(torch.sum(torch.pow(F.softshrink(v, args.v_max), 2)))
            rec_loss = np.mean(losses)
            writer.add_scalar('Rec_Loss', rec_loss, epoch)
            writer.add_scalar('Acc_Loss', acc_loss.detach().cpu().numpy(), epoch)
            writer.add_scalar('Vel_Loss', vel_loss.detach().cpu().numpy(), epoch)
            writer.add_scalar('Total_Loss',
                              rec_loss + acc_loss.detach().cpu().numpy()
                              + vel_loss.detach().cpu().numpy(), epoch)
        x = model.module.get_trajectory()
        v, a = get_vel_acc(x)
        if args.TSP and epoch < args.TSP_epoch:
            writer.add_figure('Scatter', plot_scatter(x.detach().cpu().numpy()), epoch)
        else:
            writer.add_figure('Trajectory', plot_trajectory(x.detach().cpu().numpy()), epoch)
            writer.add_figure('Scatter', plot_scatter(x.detach().cpu().numpy()), epoch)
            writer.add_figure('Accelerations_plot', plot_acc(a.cpu().numpy(), args.a_max), epoch)
            writer.add_figure('Velocity_plot', plot_acc(v.cpu().numpy(), args.v_max), epoch)
        writer.add_text('Coordinates', str(x.detach().cpu().numpy()).replace(' ', ','), epoch)
    if epoch == 0:
        return None, time.perf_counter() - start
    else:
        return np.mean(losses), time.perf_counter() - start
def __init__(self, num_basis, kernel_size, lam, solver_type='fista_custom',
             solver_params=None, size_average_recon=False,
             size_average_l1=False, im_size=None, legacy_bias=False):
    assert not size_average_recon
    assert not size_average_l1
    super().__init__()
    # TODO: consistently change the semantics of size_average,
    # or always use size_average_* = False and handle this later on.
    self.lam = lam
    self.linear_module = ConvTranspose2d(num_basis, 1, kernel_size,
                                         bias=legacy_bias)  # the official implementation has a bias term
    if legacy_bias:
        self.linear_module.bias.data.zero_()
    # TODO: handle weighted version.
    self.cost = MSELoss(reduction='sum')
    self.size_average_l1 = size_average_l1
    self.solver_type = solver_type
    self.solver_params = deepcopy(solver_params)
    # save a template for later use.
    assert im_size is not None
    self.register_buffer('_template_weight',
                         generate_weight_template(im_size, im_size, kernel_size))
    # self.register_buffer('_template_weight', Variable(Tensor(np.ones((25, 25)))))

    # define the functions for fista_custom.
    def f_fista(target: Tensor, code: Tensor, calc_grad):
        with torch.no_grad():
            recon = self.linear_module(code)
            cost_recon = 0.5 * self.cost(recon, self._template_weight * target).item()
            if not calc_grad:
                return cost_recon
            # compute grad.
            grad_this = conv2d(recon - self._template_weight * target,
                               self.linear_module.weight, None)
            return cost_recon, grad_this

    def g_fista(x: Tensor):
        cost = torch.abs(x) * self.lam
        if self.size_average_l1:
            cost = cost.mean()
        else:
            cost = cost.sum()
        return cost

    self.f_fista = f_fista
    self.g_fista = g_fista
    # proximal step: soft-thresholding with threshold lam / L
    self.pl = lambda x, L: softshrink(x, lambd=self.lam / L).data
    if self.solver_type == 'fista_custom':
        self.L = 0.1  # init L
def evaluate(args, epoch, model, data_loader, writer):
    model.eval()
    losses = []
    psnrs = []
    start = time.perf_counter()
    with torch.no_grad():
        if epoch != 0:
            for iter, data in enumerate(data_loader):
                input, target, mean, std, norm = data
                input = input.unsqueeze(1).to(args.device)
                target = target.to(args.device)
                output = model(input).squeeze(1)
                outputnorm, _, _ = transforms.normalize_instance(output, eps=1e-11)
                psnrs.append(psnr(target.cpu().numpy(), outputnorm.cpu().numpy()))
                loss = args.rec_weight * F.l1_loss(output, target)
                losses.append(loss.item())
            x = model.get_trajectory()
            v, a = get_vel_acc(x)
            acc_loss = args.acc_weight * torch.sqrt(
                torch.sum(torch.pow(F.softshrink(a, args.a_max), 2)))
            vel_loss = args.vel_weight * torch.sqrt(
                torch.sum(torch.pow(F.softshrink(v, args.v_max), 2)))
            rec_loss = np.mean(losses)
            psnr_avg = np.mean(psnrs)
            writer.add_scalar('Rec_Loss', rec_loss, epoch)
            writer.add_scalar('Acc_Loss', acc_loss.detach().cpu().numpy() / args.acc_weight, epoch)
            writer.add_scalar('Vel_Loss', vel_loss.detach().cpu().numpy() / args.vel_weight, epoch)
            writer.add_scalar('Acc_Weight', args.acc_weight, epoch)
            writer.add_scalar('Total_Loss',
                              rec_loss + acc_loss.detach().cpu().numpy()
                              + vel_loss.detach().cpu().numpy(), epoch)
            writer.add_scalar('PSNR', psnr_avg, epoch)
        x = model.get_trajectory()
        v, a = get_vel_acc(x)
        writer.add_figure('Trajectory_Proj_XY', plot_trajectory(x.detach().cpu().numpy(), 0, 1), epoch)
        writer.add_figure('Trajectory_Proj_YZ', plot_trajectory(x.detach().cpu().numpy(), 1, 2), epoch)
        writer.add_figure('Trajectory_Proj_XZ', plot_trajectory(x.detach().cpu().numpy(), 0, 2), epoch)
        writer.add_figure('Trajectory_3D', plot_trajectory(x.detach().cpu().numpy(), d3=True), epoch)
        writer.add_figure('Accelerations_plot', plot_acc(a.cpu().numpy(), args.a_max), epoch)
        writer.add_figure('Velocity_plot', plot_acc(v.cpu().numpy(), args.v_max), epoch)
        writer.add_text('Coordinates', str(x.cpu().numpy()).replace(' ', ','), epoch)
    return np.mean(losses), time.perf_counter() - start
def step(self, normalizer=None, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        lamda = group['lamda']
        normalize = group['normalize']
        for p in group['params']:
            if p.grad is None:
                continue
            if normalizer is None:
                d_p = p.grad.data
            else:
                d_p = p.grad.data.div(normalizer)
            if weight_decay != 0:
                d_p.add_(p.data, alpha=weight_decay)
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                    buf.mul_(momentum).add_(d_p)
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf
            p.data.add_(d_p, alpha=-group['lr'])
            if lamda != 0:
                # proximal step for the L1 penalty: soft-threshold the weights
                p.data = softshrink(p.data, group['lr'] * lamda)
            if normalize:
                p.data.div_(norm(p.data))
    return loss
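# In the optimizer above, softshrink is the closed-form proximal operator of
# the L1 penalty: prox_{t*lamda*||.||_1}(w) = softshrink(w, t*lamda).
# A minimal numeric sketch (hypothetical values, not from the original source):
w = torch.tensor([0.3, -0.05, 1.2])
print(F.softshrink(w, lambd=0.1 * 2.0))   # tensor([0.1000, 0.0000, 1.0000])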
def step(self, closure=None):
    """Performs a single optimization step (parameter update).

    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for pg in self.param_groups:
        lr = pg['lr']
        g = pg['gravity']
        K = pg['truncate_freq']
        weight_decay = pg['weight_decay']

        truncate_state = self.state['truncate_state']
        if 'index' not in truncate_state:
            truncate_state['index'] = 1
        truncate_index = truncate_state['index']
        do_truncate = truncate_index % K == 0

        for p in pg['params']:
            if p.grad is None:
                continue
            # gradient step
            p_grad = p.grad.data
            if weight_decay != 0:
                p_grad.add_(p.data, alpha=weight_decay)
            p.data.add_(p_grad, alpha=-lr)
            # truncate step: every K updates, soft-threshold the weights
            if do_truncate:
                shrink_param = lr * K * g
                p.data.copy_(F.softshrink(p.data, lambd=shrink_param).data)
        # advance the shared counter once per step (not once per parameter)
        if do_truncate:
            truncate_state['index'] = 1
        else:
            truncate_state['index'] += 1
    return loss
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    model.train()
    avg_loss = 0.
    if epoch >= args.weight_increase_epoch:
        args.vel_weight *= 1.5
        args.acc_weight *= 1.5
    start_epoch = start_iter = time.perf_counter()
    for iter, data in enumerate(data_loader):
        input, target, mean, std, norm = data
        input = input.unsqueeze(1).to(args.device)
        target = target.to(args.device)
        output = model(input).squeeze(1)
        x = model.get_trajectory()
        v, a = get_vel_acc(x)
        # During training, scale the acceleration/velocity limits to match the
        # coarse trajectory.
        acc_loss = args.acc_weight * torch.sqrt(torch.sum(torch.pow(
            F.softshrink(a, args.a_max * (args.realworld_points_per_shot
                                          / args.points_per_shot) ** 2 / 3), 2)))
        vel_loss = args.vel_weight * torch.sqrt(torch.sum(torch.pow(
            F.softshrink(v, args.v_max * args.realworld_points_per_shot
                         / args.points_per_shot / 2), 2)))
        rec_loss = args.rec_weight * F.l1_loss(output, target)
        loss = rec_loss + vel_loss + acc_loss
        optimizer.zero_grad()
        loss.backward()
        if args.initialization == '2dstackofstars':
            x.grad[:, :, 2] = 0  # keep the z-coordinate fixed
        optimizer.step()
        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        if iter % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'rec_loss: {rec_loss:.4g}, vel_loss: {vel_loss:.4g}, acc_loss: {acc_loss:.4g}, '
            )
            start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.init_zero_knot_indexes()
    self.init_derivative_filters()

    # Initializing the spline activation functions:
    # tensor with the locations of the spline coefficients.
    grid_tensor = self.grid_tensor  # size: (num_activations, size)
    coefficients = torch.zeros_like(grid_tensor)  # spline coefficients

    # The coefficients are initialized with the value of the activation
    # at each knot (c[k] = f[k], since B1 splines are interpolators).
    if self.init == 'even_odd':
        # Initialize half of the activations with an even function (abs)
        # and the other half with an odd function (soft threshold).
        half = self.num_activations // 2
        coefficients[0:half, :] = (grid_tensor[0:half, :]).abs()
        coefficients[half::, :] = F.softshrink(grid_tensor[half::, :], lambd=0.5)
    elif self.init == 'relu':
        coefficients = F.relu(grid_tensor)
    elif self.init == 'leaky_relu':
        coefficients = F.leaky_relu(grid_tensor, negative_slope=0.01)
    elif self.init == 'softplus':
        coefficients = F.softplus(grid_tensor, beta=3, threshold=10)
    elif self.init == 'random':
        coefficients.normal_()
    elif self.init == 'identity':
        coefficients = grid_tensor.clone()
    elif self.init != 'zero':
        raise ValueError('init should be even_odd/relu/leaky_relu/softplus/'
                         'random/identity/zero')

    # Need to vectorize the coefficients to perform specific operations.
    self.coefficients_vect = nn.Parameter(
        coefficients.contiguous().view(-1))  # size: (num_activations*size)

    # Create the finite-difference matrix.
    self.init_D()
    # Flag to keep track of the sparsification process.
    self.sparsification_flag = False
def backtracking(z, x, weight, alpha, lr0, eta=1.5, maxiter=1000, verbose=False):
    if eta <= 1:
        raise ValueError('eta must be > 1.')

    # store initial values
    resid_0 = torch.matmul(z, weight.T) - x
    fval_0 = 0.5 * resid_0.pow(2).sum()
    fgrad_0 = torch.matmul(resid_0, weight)

    def calc_F(z_1):
        resid_1 = torch.matmul(z_1, weight.T) - x
        return 0.5 * resid_1.pow(2).sum() + alpha * z_1.abs().sum()

    def calc_Q(z_1, t):
        dz = z_1 - z
        return (fval_0 + (dz * fgrad_0).sum()
                + (0.5 / t) * dz.pow(2).sum()
                + alpha * z_1.abs().sum())

    lr = lr0
    z_next = None
    for i in range(maxiter):
        z_next = F.softshrink(z - lr * fgrad_0, alpha * lr)
        F_next = calc_F(z_next)
        Q_next = calc_Q(z_next, lr)
        if verbose:
            print('iter: %4d, t: %0.5f, F-Q: %0.5f' % (i, lr, F_next - Q_next))
        if F_next <= Q_next:
            break
        lr = lr / eta
    else:
        warnings.warn('backtracking line search failed. Reverting to initial '
                      'step size')
        lr = lr0
        z_next = F.softshrink(z - lr * fgrad_0, alpha * lr)

    return z_next, lr
def ista(x, z0, weight, alpha=1.0, fast=True, lr='auto', maxiter=10,
         tol=1e-5, backtrack=False, eta_backtrack=1.5, verbose=False):
    if lr == 'auto':
        # set lr based on the maximum eigenvalue of W^T @ W, i.e. the
        # Lipschitz constant of \grad f(z), where f(z) = ||Wz - x||^2
        L = _lipschitz_constant(weight)
        lr = 1 / L
    tol = z0.numel() * tol

    def loss_fn(z_k):
        resid = torch.matmul(z_k, weight.T) - x
        loss = 0.5 * resid.pow(2).sum() + alpha * z_k.abs().sum()
        return loss / x.size(0)

    def rss_grad(z_k):
        resid = torch.matmul(z_k, weight.T) - x
        return torch.matmul(resid, weight)

    # optimize
    z = z0
    if fast:
        y, t = z0, 1
    for _ in range(maxiter):
        if verbose:
            print('loss: %0.4f' % loss_fn(z))

        # ista update
        z_prev = y if fast else z
        if backtrack:
            # perform backtracking line search
            z_next, _ = backtracking(z_prev, x, weight, alpha, lr, eta_backtrack)
        else:
            # constant step size
            z_next = F.softshrink(z_prev - lr * rss_grad(z_prev), alpha * lr)

        # check convergence
        if (z - z_next).abs().sum() <= tol:
            z = z_next
            break

        # update variables
        if fast:
            t_next = (1 + math.sqrt(1 + 4 * t ** 2)) / 2
            y = z_next + ((t - 1) / t_next) * (z_next - z)
            t = t_next
        z = z_next

    return z
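# Usage sketch (hypothetical shapes/values, not from the original source;
# passing an explicit lr avoids the module's _lipschitz_constant helper,
# which is not shown here):
W = torch.randn(64, 128)            # [D,K]
x = torch.randn(16, 64)             # [N,D]
z0 = x.new_zeros(16, 128)           # [N,K]
z = ista(x, z0, W, alpha=0.1, lr=0.01, maxiter=50)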
def step(self, closure=None):
    """Performs a single optimization step (parameter update).

    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.
    """
    loss = super(TruncateAdam, self).step(closure)

    for pg in self.param_groups:
        lr_truncate = pg['lr_truncate']
        g = pg['gravity']
        K = pg['truncate_freq']

        truncate_state = self.state['truncate_state']
        if 'index' not in truncate_state:
            truncate_state['index'] = 1
        truncate_index = truncate_state['index']
        do_truncate = truncate_index % K == 0

        for p in pg['params']:
            if p.grad is None:
                continue
            # truncate step: every K Adam updates, soft-threshold the weights
            if do_truncate:
                shrink_param = lr_truncate * K * g
                p.data.copy_(F.softshrink(p.data, lambd=shrink_param).data)
        # advance the shared counter once per step (not once per parameter)
        if do_truncate:
            truncate_state['index'] = 1
        else:
            truncate_state['index'] += 1
    return loss
def loss_fn(input, target):
    diff = target - input
    if shrink and shrink > 0:
        # dead-zone: errors smaller than `shrink` contribute zero loss
        diff = F.softshrink(diff, shrink)
    sqdf = diff ** 2
    if cap and cap > 0:
        # Huber-like capping: quadratic below `cap`, linear above it
        abdf = diff.abs()
        sqdf = torch.where(abdf < cap, sqdf, cap * (2 * abdf - cap))
    if reduction is None or reduction.lower() == "none":
        return sqdf
    elif reduction.lower() == "mean":
        return sqdf.mean()
    elif reduction.lower() == "sum":
        return sqdf.sum()
    elif reduction.lower() in ("batch", "bmean"):
        return sqdf.sum() / sqdf.shape[0]
    else:
        raise ValueError(f"{reduction} is not a valid reduction type")
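# Standalone sketch of the dead-zone effect used above (hypothetical values):
diff = torch.tensor([0.05, -0.2, 0.5])
print(F.softshrink(diff, 0.1))   # tensor([0.0000, -0.1000, 0.4000])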
def ista_forward(self, data, return_loss=False):
    """ Implements ISTA. """
    if return_loss:
        loss = []
    # Initializations.
    We_y = self.We(data)
    x = We_y.clone()
    # Iterations.
    for n in range(self.n_iter):
        x = F.softshrink(We_y + self.S(x), self.thresh)
        if return_loss:
            loss += [self.lossFcn(x, data)]
    # Fin.
    if return_loss:
        return x, loss
    return x
def __init__(self, input_size, num_basis, lam, solver_type='spams',
             solver_params=None, size_average_recon=False,
             size_average_l1=False):
    assert not size_average_recon
    assert not size_average_l1
    super().__init__()
    # TODO: consistently change the semantics of size_average,
    # or always use size_average_* = False and handle this later on.
    self.lam = lam
    self.linear_module = Linear(num_basis, input_size, bias=False)
    self.cost = MSELoss(reduction='sum')
    self.size_average_l1 = size_average_l1
    self.solver_type = solver_type
    self.solver_params = deepcopy(solver_params)

    # define the functions for fista_custom.
    def f_fista(target: Tensor, code: Tensor, calc_grad):
        recon = self.linear_module(code)
        cost_recon = self.cost(recon, target).item()
        if not calc_grad:
            return cost_recon
        # compute grad.
        grad_this = linear(recon - target, self.linear_module.weight.t())
        grad_this *= 2
        return cost_recon, grad_this

    def g_fista(x: Tensor):
        cost = torch.abs(x) * self.lam
        if self.size_average_l1:
            cost = cost.mean()
        else:
            cost = cost.sum()
        return cost

    self.f_fista = f_fista
    self.g_fista = g_fista
    # proximal step: soft-thresholding with threshold lam / L
    self.pl = lambda x, L: softshrink(x, lambd=self.lam / L).data
    if self.solver_type == 'fista_custom':
        self.L = 0.1  # init L
def salsa_forward(self, data, return_loss=False):
    """ Implements SALSA. """
    if return_loss:
        loss = []
    We_y = self.We(data)
    x = We_y
    d = torch.zeros(We_y.size()).to(self.device)
    for it in range(self.n_iter):
        u = F.softshrink(x + d, lambd=self.thresh)  # L1 proximal step
        x = self.S(We_y + self.mu * (u - d))        # quadratic subproblem
        d += x - u                                  # dual / residual update
        if return_loss:
            loss += [self.lossFcn(x, data)]
    # Fin.
    if return_loss:
        return x, loss
    return x
def test_softshrink(self):
    inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype)
    output = F.softshrink(inp, lambd=0.5)
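# F.softshrink computes sign(x) * max(|x| - lambd, 0); a quick numeric check
# of that identity (standalone sketch, not from the original source):
x = torch.randn(1000)
ref = torch.sign(x) * torch.clamp(x.abs() - 0.5, min=0.0)
assert torch.allclose(F.softshrink(x, lambd=0.5), ref)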
def __init__(self, lambd, beta, N_l, N_c):
    super(Line, self).__init__()
    self.weightx = nn.Parameter(
        torch.arange(-500, 510, 10, dtype=torch.float) / 255,
        requires_grad=False)
    # precompute a soft-thresholded lookup table, repeated over (N_l, N_c)
    l = lambd / max(beta, 1e-4)
    self.weighty = nn.Parameter(
        F.softshrink(self.weightx.view(1, 1, -1).repeat(N_l, N_c, 1), l).contiguous())
def basis_pursuit_admm(A, b, lambd, M_inv=None, tol=1e-4, max_iters=100,
                       return_stats=False):
    r"""
    Basis Pursuit solver for the :math:`Q_1^\epsilon` problem

    .. math::
        \min_x \frac{1}{2}
        \left|\left| \boldsymbol{A}\vec{x} - \vec{b} \right|\right|_2^2
        + \lambda \|x\|_1

    via the alternating direction method of multipliers (ADMM).

    Parameters
    ----------
    A : (N, M) torch.Tensor
        The input weight matrix :math:`\boldsymbol{A}`.
    b : (B, N) torch.Tensor
        The right side of the equation :math:`\boldsymbol{A}\vec{x} = \vec{b}`.
    lambd : float
        :math:`\lambda`, controls the sparsity of :math:`\vec{x}`.
    M_inv : (M, M) torch.Tensor, optional
        Precomputed transposed inverse of :math:`A^T A + I`; computed here
        when not given.
    tol : float
        The accuracy tolerance of ADMM.
    max_iters : int
        Run for at most `max_iters` iterations.
    return_stats : bool
        If True, also return the mean relative update norm and the number
        of iterations performed.

    Returns
    -------
    torch.Tensor
        (B, M) The solution vector batch :math:`\vec{x}`.
    """
    A_dot_b = b.matmul(A)
    if M_inv is None:
        M = A.t().matmul(A) + torch.eye(A.shape[1], device=A.device)
        M_inv = M.inverse().t()
        del M
    batch_size = b.shape[0]
    v = torch.zeros(batch_size, A.shape[1], device=A.device)
    u = torch.zeros(batch_size, A.shape[1], device=A.device)
    v_prev = v.clone()
    v_solution = v.clone()
    solved = torch.zeros(batch_size, dtype=torch.bool)

    iter_id = 0
    dv_norm = None
    for iter_id in range(max_iters):
        b_eff = A_dot_b + v - u
        x = b_eff.matmul(M_inv)  # M_inv is already transposed
        # x is of shape (<=B, m_atoms)
        v = F.softshrink(x + u, lambd)
        u = u + x - v
        v_norm = v.norm(dim=1)
        if (v_norm == 0).any():
            warnings.warn(f"Lambda ({lambd}) is set too large: "
                          f"the output vector is zero-valued.")
        dv_norm = (v - v_prev).norm(dim=1) / (v_norm + 1e-9)
        solved_batch = dv_norm < tol
        v, u, A_dot_b = _reduce(solved, solved_batch, v_solution, v, u, A_dot_b)
        if v.shape[0] == 0:
            # all solved
            break
        v_prev = v.clone()

    if iter_id != max_iters - 1:
        assert solved.all()
    v_solution[~solved] = v  # dump unsolved iterations

    if return_stats:
        return v_solution, dv_norm.mean(), iter_id
    return v_solution
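# Usage sketch for basis_pursuit_admm (hypothetical shapes/values; assumes the
# module's _reduce helper, called above but not shown here, is in scope):
A = torch.randn(100, 256)
A = A / A.norm(dim=0, keepdim=True)    # unit-norm atoms
b = torch.randn(8, 100)                # batch of 8 signals
x = basis_pursuit_admm(A, b, lambd=0.1, max_iters=200)   # [8, 256] sparse codes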
def step(self, normalizer=None, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            if normalizer is None:
                grad = p.grad.data
            else:
                grad = p.grad.data.div(normalizer)
            if grad.is_sparse:
                raise RuntimeError(
                    'Adam does not support sparse gradients, please consider '
                    'SparseAdam instead')
            amsgrad = group['amsgrad']
            lamda = group['lamda']
            normalize = group['normalize']

            state = self.state[p]
            # State initialization
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_sq'] = torch.zeros_like(p.data)

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
            if amsgrad:
                max_exp_avg_sq = state['max_exp_avg_sq']
            beta1, beta2 = group['betas']

            state['step'] += 1

            if group['weight_decay'] != 0:
                grad.add_(p.data, alpha=group['weight_decay'])

            # Decay the first and second moment running average coefficient
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                # Use the max. for normalizing running avg. of gradient
                denom = max_exp_avg_sq.sqrt().add_(group['eps'])
            else:
                denom = exp_avg_sq.sqrt().add_(group['eps'])

            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']
            step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

            p.data.addcdiv_(exp_avg, denom, value=-step_size)
            if lamda != 0:
                # proximal step for the L1 penalty
                p.data = softshrink(p.data, lamda)
            if normalize:
                p.data.div_(norm(p.data))
    return loss
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    model.train()
    avg_loss = 0.
    if epoch == args.TSP_epoch and args.TSP:
        x = model.get_trajectory()
        x = x.detach().cpu().numpy()
        if not args.KMEANS:
            x = tsp_solver(x)
        else:
            x = kmeans_tsp(x)
        v, a = get_vel_acc(x)
        writer.add_figure('TSP_Trajectory_Proj_XY', plot_trajectory(x, 0, 1), epoch)
        writer.add_figure('TSP_Trajectory_Proj_YZ', plot_trajectory(x, 1, 2), epoch)
        writer.add_figure('TSP_Trajectory_Proj_XZ', plot_trajectory(x, 0, 2), epoch)
        writer.add_figure('TSP_Trajectory_3D', plot_trajectory(x, d3=True), epoch)
        writer.add_figure('TSP_Acc', plot_acc(a, args.a_max), epoch)
        writer.add_figure('TSP_Vel', plot_acc(v, args.v_max), epoch)
        np.save('trajTSP', x)
        with torch.no_grad():
            model.subsampling.x.data = torch.tensor(x, device='cuda')
        args.a_max *= 2
        args.v_max *= 2
    if args.TSP and epoch > args.TSP_epoch and epoch <= args.TSP_epoch * 2:
        v0 = args.gamma * args.G_max * args.FOV * args.dt
        a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
        args.a_max -= a0 / args.TSP_epoch
        args.v_max -= v0 / args.TSP_epoch
    if args.TSP and epoch == args.TSP_epoch * 2:
        args.vel_weight *= 10
        args.acc_weight *= 10
    if args.TSP and epoch == args.TSP_epoch * 2 + 10:
        args.vel_weight *= 10
        args.acc_weight *= 10
    if args.TSP and epoch == args.TSP_epoch * 2 + 20:
        args.vel_weight *= 10
        args.acc_weight *= 10

    start_epoch = start_iter = time.perf_counter()
    for iter, data in enumerate(data_loader):
        input, target, mean, std, norm = data
        input = input.unsqueeze(1).to(args.device)
        target = target.to(args.device)
        output = model(input).squeeze(1)
        x = model.get_trajectory()
        v, a = get_vel_acc(x)
        acc_loss = args.acc_weight * torch.sqrt(
            torch.sum(torch.pow(F.softshrink(a, args.a_max), 2)))
        vel_loss = args.vel_weight * torch.sqrt(
            torch.sum(torch.pow(F.softshrink(v, args.v_max), 2)))
        rec_loss = args.rec_weight * F.l1_loss(output, target)
        if args.TSP and epoch < args.TSP_epoch:
            loss = rec_loss
        else:
            loss = rec_loss + vel_loss + acc_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        # writer.add_scalar('TrainLoss', loss.item(), iter)
        if iter % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'rec_loss: {rec_loss:.4g}, vel_loss: {vel_loss:.4g}, acc_loss: {acc_loss:.4g}, '
            )
            start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    model.train()
    avg_loss = 0.
    if epoch == args.TSP_epoch and args.TSP:
        x = model.module.get_trajectory()
        x = x.detach().cpu().numpy()
        for shot in range(x.shape[0]):
            x[shot, :, :] = tsp_solver(x[shot, :, :])
        v, a = get_vel_acc(x)
        writer.add_figure('TSP_Trajectory', plot_trajectory(x), epoch)
        writer.add_figure('TSP_Acc', plot_acc(a, args.a_max), epoch)
        writer.add_figure('TSP_Vel', plot_acc(v, args.v_max), epoch)
        np.save('trajTSP', x)
        with torch.no_grad():
            model.module.subsampling.x.data = torch.tensor(x, device='cuda')
        args.a_max *= 2
        args.v_max *= 2
        args.vel_weight = 1e-3
        args.acc_weight = 1e-3

    # Schedule the interpolation gap and the velocity/acceleration limits.
    if args.TSP:
        if epoch < args.TSP_epoch:
            model.module.subsampling.interp_gap = 1
        elif epoch < 10 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 10
            v0 = args.gamma * args.G_max * args.FOV * args.dt
            a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
            args.a_max -= a0 / args.TSP_epoch
            args.v_max -= v0 / args.TSP_epoch
        elif epoch == 10 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 10
            v0 = args.gamma * args.G_max * args.FOV * args.dt
            a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
            args.a_max -= a0 / args.TSP_epoch
            args.v_max -= v0 / args.TSP_epoch
        elif epoch == 15 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 10
            v0 = args.gamma * args.G_max * args.FOV * args.dt
            a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
            args.a_max -= a0 / args.TSP_epoch
            args.v_max -= v0 / args.TSP_epoch
        elif epoch == 20 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 10
            args.vel_weight *= 10
            args.acc_weight *= 10
        elif epoch == 23 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 5
            args.vel_weight *= 10
            args.acc_weight *= 10
        elif epoch == 25 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 1
            args.vel_weight *= 10
            args.acc_weight *= 10
    else:
        if epoch < 10:
            model.module.subsampling.interp_gap = 50
        elif epoch == 10:
            model.module.subsampling.interp_gap = 30
        elif epoch == 15:
            model.module.subsampling.interp_gap = 20
        elif epoch == 20:
            model.module.subsampling.interp_gap = 10
        elif epoch == 23:
            model.module.subsampling.interp_gap = 5
        elif epoch == 25:
            model.module.subsampling.interp_gap = 1

    start_epoch = start_iter = time.perf_counter()
    print(f'a_max={args.a_max}, v_max={args.v_max}')
    for iter, data in enumerate(data_loader):
        optimizer.zero_grad()
        input, target, mean, std, norm = data
        input = input.to(args.device)
        target = target.to(args.device)
        output = model(input)
        # output = transforms.complex_abs(output)  # complex to real
        # output = transforms.root_sum_of_squares(output, dim=1)
        output = output.squeeze()
        x = model.module.get_trajectory()
        v, a = get_vel_acc(x)
        acc_loss = torch.sqrt(torch.sum(torch.pow(
            F.softshrink(a, args.a_max).abs() + 1e-8, 2)))
        vel_loss = torch.sqrt(torch.sum(torch.pow(
            F.softshrink(v, args.v_max).abs() + 1e-8, 2)))
        rec_loss = F.l1_loss(output, target)
        if args.TSP and epoch < args.TSP_epoch:
            loss = args.rec_weight * rec_loss
        else:
            loss = (args.rec_weight * rec_loss
                    + args.vel_weight * vel_loss
                    + args.acc_weight * acc_loss)
        loss.backward()
        optimizer.step()
        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        # writer.add_scalar('TrainLoss', loss.item(), global_step + iter)
        if iter % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'rec_loss: {rec_loss:.4g}, vel_loss: {vel_loss:.4g}, acc_loss: {acc_loss:.4g}'
            )
            start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch
''' tanh / hardtanh / softsign '''
x = torch.arange(-5, 5, 0.1).view(-1, 1)
y = torch.tanh(x)
pic(x, y, (2, 4, 3), 'tanh')
y = F.hardtanh(x)
pic(x, y, (2, 4, 3), 'hardtanh')
y = F.softsign(x)
pic(x, y, (2, 4, 3), 'softsign')

''' hardshrink / tanhshrink / softshrink '''
x = torch.arange(-3, 3, 0.1).view(-1, 1)
y = F.hardshrink(x)
pic(x, y, (2, 4, 4), 'hardshrink')
y = F.tanhshrink(x)
pic(x, y, (2, 4, 4), 'tanhshrink')
y = F.softshrink(x)
pic(x, y, (2, 4, 4), 'softshrink')

''' sigmoid '''
x = torch.arange(-5, 5, 0.1).view(-1, 1)
y = torch.sigmoid(x)
pic(x, y, (2, 4, 5), 'sigmoid')

''' relu6 '''
x = torch.arange(-5, 10, 0.1).view(-1, 1)
y = F.relu6(x)
pic(x, y, (2, 4, 6), 'relu6')

''' logsigmoid '''
x = torch.arange(-3, 3, 0.1).view(-1, 1)
y = F.logsigmoid(x)
def coord_descent_mod(x, W, z0=None, alpha=1.0, max_iter=1000, tol=1e-4):
    """Modified variant of the CD algorithm

    Based on `enet_coordinate_descent` from sklearn.linear_model._cd_fast
    This version is much slower, but it produces more reliable results as
    compared to the above.

    x : Tensor of shape [n_samples, n_features]
    W : Tensor of shape [n_features, n_components]
    z : Tensor of shape [n_samples, n_components]
    """
    n_features, n_components = W.shape
    n_samples = x.shape[0]
    assert x.shape[1] == n_features

    if z0 is None:
        z = x.new_zeros(n_samples, n_components)  # [N,K]
    else:
        assert z0.shape == (n_samples, n_components)
        z = z0

    gap = z.new_full((n_samples,), tol + 1.)
    converged = z.new_zeros(n_samples, dtype=torch.bool)
    d_w_tol = tol
    tol = tol * x.pow(2).sum(1)  # [N,]

    # compute squared norms of the columns of X
    norm_cols_X = W.pow(2).sum(0)  # [K,]

    # function to check convergence state (per sample)
    def _check_convergence(z_, x_, R_, tol_):
        XtA = torch.mm(R_, W)  # [N,K]
        dual_norm_XtA = XtA.abs().max(1)[0]  # [N,]
        R_norm2 = R_.pow(2).sum(1)  # [N,]
        small_norm = dual_norm_XtA <= alpha
        const = (alpha / dual_norm_XtA).masked_fill(small_norm, 1.)
        gap = torch.where(small_norm, R_norm2, 0.5 * R_norm2 * (1 + const.pow(2)))
        gap = gap + alpha * z_.abs().sum(1) - const * (R_ * x_).sum(1)
        converged = gap < tol_
        return converged, gap

    # initialize residual
    R = x - torch.matmul(z, W.T)  # [N,D]

    for n_iter in range(max_iter):
        if converged.all():
            break
        active_ix, = torch.where(~converged)
        z_max = z.new_zeros(len(active_ix))
        d_z_max = z.new_zeros(len(active_ix))
        for i in range(n_components):  # loop over components
            if norm_cols_X[i] == 0:
                continue
            atom_i = W[:, i].contiguous()
            z_i = z[active_ix, i].clone()
            nonzero = z_i != 0
            R[active_ix[nonzero]] += torch.outer(z_i[nonzero], atom_i)
            z[active_ix, i] = F.softshrink(R[active_ix].matmul(atom_i), alpha)
            z[active_ix, i] /= norm_cols_X[i]
            z_new_i = z[active_ix, i]
            nonzero = z_new_i != 0
            R[active_ix[nonzero]] -= torch.outer(z_new_i[nonzero], atom_i)
            # update the maximum absolute coefficient update
            d_z_max = torch.maximum(d_z_max, (z_new_i - z_i).abs())
            z_max = torch.maximum(z_max, z_new_i.abs())

        ### check convergence ###
        check = (z_max == 0) | (d_z_max / z_max < d_w_tol) | (n_iter == max_iter - 1)
        if not check.any():
            continue
        check_ix = active_ix[check]
        converged[check_ix], gap[check_ix] = \
            _check_convergence(z[check_ix], x[check_ix], R[check_ix], tol[check_ix])

    return z, gap
def softshrink(self, x, lambd=0.5):
    return F.softshrink(x, lambd)