Example #1
    def forward(self, x):
        residual = x
        y1 = residual

        if self.downsample is not None:
            y1 = self.downsample(residual)

        y2 = F.conv2d(residual, self.weight1)
        y2 = self.bn1(y2)
        y2 = F.softshrink(y2, self.tau)

        w = self.weight2.repeat(self.out_chnls, 1, 1, 1)
        y2 = F.conv2d(y2,
                      w,
                      stride=self.stride,
                      padding=self.padding,
                      groups=self.out_chnls)
        y2 = self.bn2(y2)
        y2 = self.activation(y2)

        y2 = self.conv3(y2)
        y2 = self.bn3(y2)

        y = y1 + y2
        y = F.softshrink(y, self.tau)

        return y
Example #2
def encode(model, x):
    n = x.shape[1]
    N = x.shape[0]
    x = x.t()

    tol = 1e-10
    D = model.D.detach()
    lambd_eta = lambd = model.eta * model.lambd
    Dtx = torch.matmul(D.t(), x).detach()
    z_ = F.softshrink(model.eta * Dtx, lambd_eta)
    for i in range(model.T):
        res = torch.matmul(D, z_) - x
        Dtres = torch.matmul(D.t(), res).detach()
        z = F.softshrink(z_ - model.eta * Dtres, lambd=model.lambd * model.eta)
        if torch.norm(z - z_) < tol:
            break
        else:
            z_ = z

    # computing closed form solution given sign pattern
    alfa = z
    SIGNS = torch.sign(z)
    for i in range(N):
        S = (z[:, i] != 0)
        if np.sum(S.numpy()) == 0:
            break
        else:
            Gs = torch.matmul(model.D[:, S].t(), model.D[:, S])
            b = (torch.matmul(model.D[:, S].t(), x[:, i]) -
                 model.lambd * SIGNS[S, i])
            alfa[S, i] = torch.matmul(torch.inverse(Gs), b)

    return alfa
Example #3
def coord_descent(x, W, z0=None, alpha=1.0, maxiter=1000, tol=1e-6, verbose=False):
    input_dim, code_dim = W.shape  # [D,K]
    batch_size, input_dim1 = x.shape  # [N,D]
    assert input_dim1 == input_dim
    tol = tol * code_dim
    if z0 is None:
        z = x.new_zeros(batch_size, code_dim)  # [N,K]
    else:
        assert z0.shape == (batch_size, code_dim)
        z = z0

    # initialize b
    # TODO: how should we initialize b when 'z0' is provided?
    b = torch.mm(x, W)  # [N,K]

    # precompute S = I - W^T @ W
    S = - torch.mm(W.T, W)  # [K,K]
    S.diagonal().add_(1.)

    # loss function
    def fn(z):
        x_hat = torch.matmul(z, W.T)
        loss = 0.5 * (x_hat - x).pow(2).sum() + alpha * z.abs().sum()
        return loss

    # update function
    def cd_update(z, b):
        z_next = F.softshrink(b, alpha)  # [N,K]
        z_diff = z_next - z  # [N,K]
        k = z_diff.abs().argmax(1)  # [N]
        kk = k.unsqueeze(1)  # [N,1]
        b = b + S[:,k].T * z_diff.gather(1, kk)  # [N,K] += [N,K] * [N,1]
        z = z.scatter(1, kk, z_next.gather(1, kk))
        return z, b

    active = torch.arange(batch_size, device=W.device)
    for i in range(maxiter):
        if len(active) == 0:
            break
        z_old = z[active]
        z_new, b[active] = cd_update(z_old, b[active])
        update = (z_new - z_old).abs().sum(1)
        z[active] = z_new
        active = active[update > tol]
        if verbose:
            print('iter %i - loss: %0.4f' % (i, fn(F.softshrink(b, alpha))))

    z = F.softshrink(b, alpha)

    return z
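A minimal usage sketch for coord_descent (my addition, not part of the scraped source); it assumes torch and torch.nn.functional as F are imported as in the example, and the shapes follow the comments above:

# Hypothetical smoke test with a random dictionary and a random batch.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(64, 256)                 # [D,K] dictionary
W = W / W.norm(dim=0, keepdim=True)      # unit-norm atoms keep S = I - W^T W well scaled
x = torch.randn(128, 64)                 # [N,D] batch of signals
z = coord_descent(x, W, alpha=0.5, maxiter=500, tol=1e-6)
print(z.shape, (z != 0).float().mean())  # code shape [128, 256] and fraction of active entries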
Example #4
 def fista_forward(self, data, return_loss=False):
     """
     Implements FISTA.
     """
     if return_loss:
         loss = []
     # Initializations.
     yk = self.We(data)
     xprev = torch.zeros(yk.size()).to(self.device)
     t = 1
     # Iterations.
     for it in range(self.n_iter):
         # Update logistics.
         residual = F.linear(yk, self.We.weight.t()) - data
         # ISTA step
         tmp = yk - self.We(residual) / self.L
         xk = F.softshrink(tmp, lambd=self.thresh)
         # FISTA stepsize update:
         tnext = (1 + (1 + 4 * (t**2))**.5) / 2
         fact = (t - 1) / tnext
         # Use momentum to update code estimate.
         yk = xk + (xk - xprev) * fact
         # Update "prev" stuff.
         xprev = xk
         t = tnext
         if return_loss:
             loss += [self.lossFcn(yk, data)]
     # Fin.
     if return_loss:
         return yk, loss
     return yk
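For reference, the three update lines above follow the standard FISTA recursion (my restatement; soft denotes F.softshrink, lambda = self.thresh, L = self.L, and the gradient term is the one computed through self.We(residual)):

x_k = \operatorname{soft}_{\lambda}\!\Big(y_k - \tfrac{1}{L}\,\nabla f(y_k)\Big),
\qquad t_{k+1} = \frac{1 + \sqrt{1 + 4 t_k^{2}}}{2},
\qquad y_{k+1} = x_k + \frac{t_k - 1}{t_{k+1}}\,(x_k - x_{k-1}).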
Example #5
 def forward(self, input, k, d):
     hks = k.shape[-1] // 2
     hds = d.shape[-1] // 2
     x_padding = (hks, hks, hks, hks)
     r_padding = (hds, hds, hds, hds)
     output = []
     for c in range(input.size(1)):
         y = input[:, c].unsqueeze(1)
         x = y.clone()
         for i in range(self.n_iter):
             # z update
             z = F.conv2d(F.pad(x, (2, 2, 2, 2), 'replicate'), self.weight)
             z = F.softshrink(z,
                              self.lambd / max(1e-4, self.beta[i].item()))
             # x update
             for j in range(self.n_in):
                 r0 = y - F.conv2d(F.pad(x, x_padding, 'replicate'), k)
                 r1 = z - F.conv2d(F.pad(x, (2, 2, 2, 2), 'replicate'),
                                   self.weight)
                 r = torch.cat([r0, r1], dim=1)
                 r_pad = F.pad(r, r_padding, 'replicate')
                 for l in range(3):
                     x = x + F.conv2d(r_pad[:, l].unsqueeze(0),
                                      d[i, l].unsqueeze(0).unsqueeze(0))
             x = x.clamp(0, 1)
         output.append(x.clone())
     output = torch.cat(output, 1)
     return output
Example #6
def resample(image: torch.Tensor, displacement, image_is_displacement=False):
    shape = displacement[0].shape

    grid = torch.meshgrid(
        torch.linspace(-1, 2 * shape[0] / image.shape[0] - 1,
                       shape[0]).float().cuda(),
        torch.linspace(-1, 2 * shape[1] / image.shape[1] - 1,
                       shape[1]).float().cuda())

    df = torch.stack([
        torch.add(grid[1 - i], 2 / image.shape[1 - i], displacement[i])
        for i in range(2)
    ], 2)[None, ]

    image = image[None, None, ]

    if not image_is_displacement:
        out = F.grid_sample(image, df, padding_mode='zeros')
    else:
        out = F.grid_sample(image, df, padding_mode='border')

        dx = torch.mean((image[:, :, :, -1] - image[:, :, :, 0]), 2) / 2
        dy = torch.mean((image[:, :, -1, :] - image[:, :, 0, :]), 2) / 2
        pad_offset = F.softshrink(df, 1) * torch.stack(
            (dx, dy), 2)[:, :, None, None, :]
        pad_offset = torch.sum(pad_offset, 4)
        out += pad_offset
    return out[0][0]
Example #7
 def cd_update(z, b):
     z_next = F.softshrink(b, alpha)  # [N,K]
     z_diff = z_next - z  # [N,K]
     k = z_diff.abs().argmax(1)  # [N]
     kk = k.unsqueeze(1)  # [N,1]
     b = b + S[:,k].T * z_diff.gather(1, kk)  # [N,K] += [N,K] * [N,1]
     z = z.scatter(1, kk, z_next.gather(1, kk))
     return z, b
Example #8
def evaluate(args, epoch, model, data_loader, writer):
    model.eval()
    losses = []
    start = time.perf_counter()
    with torch.no_grad():
        if epoch != 0:
            for iter, data in enumerate(data_loader):
                input, target, mean, std, norm = data
                input = input.to(args.device)
                target = target.to(args.device)

                output = model(input)
                # output = transforms.complex_abs(output)  # complex to real
                # output = transforms.root_sum_of_squares(output, dim=1)
                output = output.squeeze()

                loss = F.l1_loss(output, target)
                losses.append(loss.item())

            x = model.module.get_trajectory()
            v, a = get_vel_acc(x)
            acc_loss = torch.sqrt(torch.sum(torch.pow(F.softshrink(a, args.a_max), 2)))
            vel_loss = torch.sqrt(torch.sum(torch.pow(F.softshrink(v, args.v_max), 2)))
            rec_loss = np.mean(losses)

            writer.add_scalar('Rec_Loss', rec_loss, epoch)
            writer.add_scalar('Acc_Loss', acc_loss.detach().cpu().numpy(), epoch)
            writer.add_scalar('Vel_Loss', vel_loss.detach().cpu().numpy(), epoch)
            writer.add_scalar('Total_Loss',
                              rec_loss + acc_loss.detach().cpu().numpy() + vel_loss.detach().cpu().numpy(), epoch)

        x = model.module.get_trajectory()
        v, a = get_vel_acc(x)
        if args.TSP and epoch < args.TSP_epoch:
            writer.add_figure('Scatter', plot_scatter(x.detach().cpu().numpy()), epoch)
        else:
            writer.add_figure('Trajectory', plot_trajectory(x.detach().cpu().numpy()), epoch)
            writer.add_figure('Scatter', plot_scatter(x.detach().cpu().numpy()), epoch)
        writer.add_figure('Accelerations_plot', plot_acc(a.cpu().numpy(), args.a_max), epoch)
        writer.add_figure('Velocity_plot', plot_acc(v.cpu().numpy(), args.v_max), epoch)
        writer.add_text('Coordinates', str(x.detach().cpu().numpy()).replace(' ', ','), epoch)
    if epoch == 0:
        return None, time.perf_counter() - start
    else:
        return np.mean(losses), time.perf_counter() - start
Example #9
    def __init__(self, num_basis, kernel_size, lam,
                 solver_type='fista_custom', solver_params=None,
                 size_average_recon=False,
                 size_average_l1=False, im_size=None, legacy_bias=False):
        assert not size_average_recon
        assert not size_average_l1
        super().__init__()

        # TODO: consistently change the semantic of size_average
        # or I always use size_average_* = False and handle this later on.
        self.lam = lam
        self.linear_module = ConvTranspose2d(num_basis, 1, kernel_size, bias=legacy_bias)
        # the official implementation has a bias term.
        if legacy_bias:
            self.linear_module.bias.data.zero_()
        # TODO handle weighted version.
        self.cost = MSELoss(reduction='sum')
        self.size_average_l1 = size_average_l1
        self.solver_type = solver_type
        self.solver_params = deepcopy(solver_params)
        # save a template for later use.
        assert im_size is not None
        self.register_buffer('_template_weight', generate_weight_template(im_size,
                                                                          im_size,
                                                                          kernel_size))

        # self.register_buffer('_template_weight', Variable(Tensor(np.ones((25, 25)))))

        # define the function for fista_custom.
        def f_fista(target: Tensor, code: Tensor, calc_grad):
            with torch.no_grad():
                recon = self.linear_module(code)
                cost_recon = 0.5 * self.cost(recon, self._template_weight * target).item()
                if not calc_grad:
                    return cost_recon
                else:
                    # compute grad.
                    grad_this = conv2d(recon - self._template_weight * target,
                                       self.linear_module.weight, None)
                    return cost_recon, grad_this

        def g_fista(x: Tensor):
            cost = torch.abs(x) * self.lam
            if self.size_average_l1:
                cost = cost.mean()
            else:
                cost = cost.sum()
            return cost

        self.f_fista = f_fista
        self.g_fista = g_fista
        self.pl = lambda x, L: softshrink(x, lambd=self.lam / L).data

        if self.solver_type == 'fista_custom':
            self.L = 0.1  # init L
Example #10
def evaluate(args, epoch, model, data_loader, writer):
    model.eval()
    losses = []
    psnrs = []
    start = time.perf_counter()
    with torch.no_grad():
        if epoch != 0:
            for iter, data in enumerate(data_loader):
                input, target, mean, std, norm = data
                input = input.unsqueeze(1).to(args.device)
                target = target.to(args.device)
                output = model(input).squeeze(1)
                outputnorm, _, _ = transforms.normalize_instance(output, eps=1e-11)
                psnrs.append(psnr(target.cpu().numpy(), outputnorm.cpu().numpy()))
                loss = args.rec_weight * F.l1_loss(output, target)
                losses.append(loss.item())

            x = model.get_trajectory()
            v, a = get_vel_acc(x)
            acc_loss = args.acc_weight * torch.sqrt(torch.sum(torch.pow(F.softshrink(a, args.a_max), 2)))
            vel_loss = args.vel_weight * torch.sqrt(torch.sum(torch.pow(F.softshrink(v, args.v_max), 2)))
            rec_loss = np.mean(losses)
            psnr_avg = np.mean(psnrs)

            writer.add_scalar('Rec_Loss', rec_loss, epoch)
            writer.add_scalar('Acc_Loss', acc_loss.detach().cpu().numpy()/args.acc_weight, epoch)
            writer.add_scalar('Vel_Loss', vel_loss.detach().cpu().numpy()/args.vel_weight, epoch)
            writer.add_scalar('Acc_Weight', args.acc_weight, epoch)
            writer.add_scalar('Total_Loss',
                              rec_loss + acc_loss.detach().cpu().numpy() + vel_loss.detach().cpu().numpy(), epoch)
            writer.add_scalar('PSNR', psnr_avg, epoch)

        x = model.get_trajectory()
        v, a = get_vel_acc(x)
        writer.add_figure('Trajectory_Proj_XY', plot_trajectory(x.detach().cpu().numpy(),0,1), epoch)
        writer.add_figure('Trajectory_Proj_YZ', plot_trajectory(x.detach().cpu().numpy(),1,2), epoch)
        writer.add_figure('Trajectory_Proj_XZ', plot_trajectory(x.detach().cpu().numpy(),0,2), epoch)
        writer.add_figure('Trajectory_3D', plot_trajectory(x.detach().cpu().numpy(),d3=True), epoch)
        writer.add_figure('Accelerations_plot', plot_acc(a.cpu().numpy(), args.a_max), epoch)
        writer.add_figure('Velocity_plot', plot_acc(v.cpu().numpy(), args.v_max), epoch)
        writer.add_text('Coordinates', str(x.cpu().numpy()).replace(' ', ','), epoch)
    return np.mean(losses), time.perf_counter() - start
Example #11
    def step(self, normalizer=None, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lamda = group['lamda']
            normalize = group['normalize']

            for p in group['params']:
                if p.grad is None:
                    continue

                if normalizer is None:
                    d_p = p.grad.data
                else:
                    d_p = p.grad.data.div(normalizer)

                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state[
                            'momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                p.data.add_(-group['lr'], d_p)

                if lamda != 0:
                    p.data = softshrink(p.data, group['lr'] * lamda)

                if normalize:
                    p.data.div_(norm(p.data))

        return loss
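Taken together, the gradient step followed by the shrink is a proximal-gradient (ISTA-style) update for an L1 penalty on the parameters; schematically (my restatement of the lines above, not text from the source):

p \leftarrow \operatorname{soft}_{\eta\lambda}\big(p - \eta\, d_p\big), \qquad \text{optionally } p \leftarrow p / \|p\|,

where \eta is group['lr'], \lambda is lamda, and soft is the softshrink operator.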
Example #12
    def step(self, closure=None):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for pg in self.param_groups:
            lr = pg['lr']
            g = pg['gravity']
            K = pg['truncate_freq']
            weight_decay = pg['weight_decay']

            truncate_state = self.state['truncate_state']
            if 'index' not in truncate_state:
                truncate_state['index'] = 1
            truncate_index = truncate_state['index']

            for p in pg['params']:
                if p.grad is None:
                    continue

                # # =============== debug code start ===============
                # if 'global_step' not in self.state:
                #     self.state['global_step'] = 0
                # else:
                #     self.state['global_step'] += 1
                # if self.state['global_step'] % 20 == 0:
                #     print('-'*5 + 'Enter TruncateSGD step')
                #     print('index:{}, param.size:{}, lr:{}, g:{}, K:{}, weight_decay:{}'.format(
                #         truncate_index, p.size(), lr, g, K, weight_decay))
                #     num_nz_mask = len(p.data.nonzero())
                #     print('non-zero coefficient mask {}'.format(num_nz_mask))
                # # =============== debug code end ===============

                # gradient step
                p_grad = p.grad.data
                if weight_decay != 0:
                    p_grad.add_(weight_decay, p.data)
                p.data.add_(-lr, p_grad)

                # truncate step
                if truncate_index > 0 and truncate_index % K == 0:
                    shrink_param = lr * K * g
                    p.data.copy_(F.softshrink(p.data, lambd=shrink_param).data)
                    truncate_state['index'] = 1
                else:
                    truncate_state['index'] += 1
        return loss
Example #13
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    model.train()
    avg_loss = 0.

    if epoch >= args.weight_increase_epoch:
        args.vel_weight *= 1.5
        args.acc_weight *= 1.5
    start_epoch = start_iter = time.perf_counter()
    for iter, data in enumerate(data_loader):
        input, target, mean, std, norm = data
        input = input.unsqueeze(1).to(args.device)
        target = target.to(args.device)
        output = model(input).squeeze(1)
        x = model.get_trajectory()
        v, a = get_vel_acc(x)

        # During training, adjust the acc/vel limits for the coarse trajectory.
        acc_loss = args.acc_weight * torch.sqrt(torch.sum(
            torch.pow(F.softshrink(a, args.a_max * (args.realworld_points_per_shot / args.points_per_shot) ** 2 / 3),
                      2)))
        vel_loss = args.vel_weight * torch.sqrt(torch.sum(
            torch.pow(F.softshrink(v, args.v_max * args.realworld_points_per_shot / args.points_per_shot / 2), 2)))
        rec_loss = args.rec_weight * F.l1_loss(output, target)

        loss = rec_loss + vel_loss + acc_loss
        optimizer.zero_grad()
        loss.backward()
        if args.initialization == '2dstackofstars':
            x.grad[:,:,2]=0
        optimizer.step()
        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        if iter % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'rec_loss: {rec_loss:.4g}, vel_loss: {vel_loss:.4g}, acc_loss: {acc_loss:.4g}, '
            )
        start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch
Example #14
    def __init__(self, **kwargs):

        super().__init__(**kwargs)

        self.init_zero_knot_indexes()
        self.init_derivative_filters()

        # Initializing the spline activation functions

        # tensor with locations of spline coefficients
        grid_tensor = self.grid_tensor  # size: (num_activations, size)
        coefficients = torch.zeros_like(grid_tensor)  # spline coefficients

        # The coefficients are initialized with the value of the activation
        # at each knot (c[k] = f[k], since B1 splines are interpolators).
        if self.init == 'even_odd':
            # initialize half of the activations with an even function (abs)
            # and the other half with an odd function (soft threshold).
            half = self.num_activations // 2
            coefficients[0:half, :] = (grid_tensor[0:half, :]).abs()
            coefficients[half::, :] = F.softshrink(grid_tensor[half::, :],
                                                   lambd=0.5)

        elif self.init == 'relu':
            coefficients = F.relu(grid_tensor)

        elif self.init == 'leaky_relu':
            coefficients = F.leaky_relu(grid_tensor, negative_slope=0.01)

        elif self.init == 'softplus':
            coefficients = F.softplus(grid_tensor, beta=3, threshold=10)

        elif self.init == 'random':
            coefficients.normal_()

        elif self.init == 'identity':
            coefficients = grid_tensor.clone()

        elif self.init != 'zero':
            raise ValueError(
                'init should be even_odd/relu/leaky_relu/softplus/'
                'random/identity/zero')

        # Need to vectorize coefficients to perform specific operations
        self.coefficients_vect = nn.Parameter(
            coefficients.contiguous().view(-1))  # size: (num_activations*size)

        # Create the finite-difference matrix
        self.init_D()

        # Flag to keep track of sparsification process
        self.sparsification_flag = False
Example #15
def backtracking(z, x, weight, alpha, lr0, eta=1.5, maxiter=1000, verbose=False):
    if eta <= 1:
        raise ValueError('eta must be > 1.')

    # store initial values
    resid_0 = torch.matmul(z, weight.T) - x
    fval_0 = 0.5 * resid_0.pow(2).sum()
    fgrad_0 = torch.matmul(resid_0, weight)

    def calc_F(z_1):
        resid_1 = torch.matmul(z_1, weight.T) - x
        return 0.5 * resid_1.pow(2).sum() + alpha * z_1.abs().sum()

    def calc_Q(z_1, t):
        dz = z_1 - z
        return (fval_0
                + (dz * fgrad_0).sum()
                + (0.5 / t) * dz.pow(2).sum()
                + alpha * z_1.abs().sum())

    lr = lr0
    z_next = None
    for i in range(maxiter):
        z_next = F.softshrink(z - lr * fgrad_0, alpha * lr)
        F_next = calc_F(z_next)
        Q_next = calc_Q(z_next, lr)
        if verbose:
            print('iter: %4d,  t: %0.5f,  F-Q: %0.5f' % (i, lr, F_next-Q_next))
        if F_next <= Q_next:
            break
        lr = lr / eta
    else:
        warnings.warn('backtracking line search failed. Reverting to initial '
                      'step size')
        lr = lr0
        z_next = F.softshrink(z - lr * fgrad_0, alpha * lr)

    return z_next, lr
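The accept/reject test in the loop is the usual proximal-gradient backtracking condition, in the notation of calc_F and calc_Q (my summary, not text from the source): the candidate z' = \operatorname{soft}_{\alpha t}\big(z - t\,\nabla f(z)\big) is accepted once

F(z') \;\le\; Q_t(z', z) \;=\; f(z) + \langle z' - z,\, \nabla f(z) \rangle + \tfrac{1}{2t}\,\|z' - z\|_2^{2} + \alpha\,\|z'\|_1,

and otherwise the step size t is divided by eta and the prox step is retried.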
Example #16
def ista(x, z0, weight, alpha=1.0, fast=True, lr='auto', maxiter=10,
         tol=1e-5, backtrack=False, eta_backtrack=1.5, verbose=False):
    if lr == 'auto':
        # set lr based on the maximum eigenvalue of W^T @ W; i.e. the
        # Lipschitz constant of \grad f(z), where f(z) = ||Wz - x||^2
        L = _lipschitz_constant(weight)
        lr = 1 / L
    tol = z0.numel() * tol

    def loss_fn(z_k):
        resid = torch.matmul(z_k, weight.T) - x
        loss = 0.5 * resid.pow(2).sum() + alpha * z_k.abs().sum()
        return loss / x.size(0)

    def rss_grad(z_k):
        resid = torch.matmul(z_k, weight.T) - x
        return torch.matmul(resid, weight)

    # optimize
    z = z0
    if fast:
        y, t = z0, 1
    for _ in range(maxiter):
        if verbose:
            print('loss: %0.4f' % loss_fn(z))

        # ista update
        z_prev = y if fast else z
        if backtrack:
            # perform backtracking line search
            z_next, _ = backtracking(z_prev, x, weight, alpha, lr, eta_backtrack)
        else:
            # constant step size
            z_next = F.softshrink(z_prev - lr * rss_grad(z_prev), alpha * lr)

        # check convergence
        if (z - z_next).abs().sum() <= tol:
            z = z_next
            break

        # update variables
        if fast:
            t_next = (1 + math.sqrt(1 + 4 * t**2)) / 2
            y = z_next + ((t-1)/t_next) * (z_next - z)
            t = t_next
        z = z_next

    return z
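A small usage sketch (my addition); it passes lr explicitly and leaves backtrack=False so that the helpers _lipschitz_constant and backtracking from the same module are not needed:

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(64, 256)                    # [D,K] dictionary
weight = weight / weight.norm(dim=0, keepdim=True)
x = torch.randn(32, 64)                          # [N,D] batch
z0 = x.new_zeros(32, 256)                        # [N,K] initial codes
z = ista(x, z0, weight, alpha=0.5, fast=True, lr=0.05, maxiter=50)
print(z.shape)                                   # torch.Size([32, 256])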
Example #17
    def step(self, closure=None):
        """Performs a single optimization step (parameter update).

        Arguments:
            closure (callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.
        """
        loss = super(TruncateAdam, self).step(closure)

        for pg in self.param_groups:
            lr_truncate = pg['lr_truncate']
            g = pg['gravity']
            K = pg['truncate_freq']

            truncate_state = self.state['truncate_state']
            if 'index' not in truncate_state:
                truncate_state['index'] = 1
            truncate_index = truncate_state['index']

            for p in pg['params']:
                if p.grad is None:
                    continue

                # truncate step
                if truncate_index > 0 and truncate_index % K == 0:
                    shrink_param = lr_truncate * K * g
                    p.data.copy_(F.softshrink(p.data, lambd=shrink_param).data)
                    truncate_state['index'] = 1
                else:
                    truncate_state['index'] += 1

                # # =============== debug code start ===============
                # if 'global_step' not in self.state:
                #     self.state['global_step'] = 0
                # else:
                #     self.state['global_step'] += 1
                # if self.state['global_step'] % 50 == 0:
                #     print('-'*5 + 'Enter TruncateAdam step')
                #     print('index:{}, param.size:{}, lr:{}, g:{}, K:{}'.format(
                #         truncate_index, p.size(), lr_truncate, g, K))
                #     num_nz_mask = len(p.data.nonzero())
                #     print('non-zero coefficient mask {}'.format(num_nz_mask))
                # # =============== debug code end ===============

        return loss
Example #18
 def loss_fn(input, target):
     diff = target - input
     if shrink and shrink > 0:
         diff = F.softshrink(diff, shrink)
     sqdf = diff**2
     if cap and cap > 0:
         abdf = diff.abs()
         sqdf = torch.where(abdf < cap, sqdf, cap * (2 * abdf - cap))
     if reduction is None or reduction.lower() == "none":
         return sqdf
     elif reduction.lower() == "mean":
         return sqdf.mean()
     elif reduction.lower() == "sum":
         return sqdf.sum()
     elif reduction.lower() in ("batch", "bmean"):
         return sqdf.sum() / sqdf.shape[0]
     else:
         raise ValueError(f"{reduction} is not a valid reduction type")
Example #19
 def ista_forward(self, data, return_loss=False):
     """
     Implements ISTA.
     """
     if return_loss:
         loss = []
     # Initializations.
     We_y = self.We(data)
     x = We_y.clone()
     # Iterations.
     for n in range(self.n_iter):
         x = F.softshrink(We_y + self.S(x), self.thresh)
         if return_loss:
             loss += [self.lossFcn(x, data)]
     # Fin.
     if return_loss:
         return x, loss
     return x
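The single line inside the loop is ISTA written in fixed-point form (my restatement; We_y, self.S and self.thresh are attributes of the module):

x^{(n+1)} = \operatorname{soft}_{\theta}\!\big(W_e\, y + S\, x^{(n)}\big),

which for plain ISTA corresponds to W_e = \tfrac{1}{L} W^{\top} and S = I - \tfrac{1}{L} W^{\top} W.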
Example #20
    def __init__(self, input_size, num_basis, lam,
                 solver_type='spams', solver_params=None,
                 size_average_recon=False,
                 size_average_l1=False):
        assert not size_average_recon
        assert not size_average_l1
        super().__init__()

        # TODO: consistently change the semantic of size_average
        # or I always use size_average_* = False and handle this later on.
        self.lam = lam
        self.linear_module = Linear(num_basis, input_size, bias=False)
        self.cost = MSELoss(reduction='sum')
        self.size_average_l1 = size_average_l1
        self.solver_type = solver_type
        self.solver_params = deepcopy(solver_params)

        # define the function for fista_custom.
        def f_fista(target: Tensor, code: Tensor, calc_grad):
            recon = self.linear_module(code)
            cost_recon = self.cost(recon, target).item()
            if not calc_grad:
                return cost_recon
            else:
                # compute grad.
                grad_this = linear(recon - target,
                                   self.linear_module.weight.t())
                grad_this *= 2
                return cost_recon, grad_this

        def g_fista(x: Tensor):
            cost = torch.abs(x) * self.lam
            if self.size_average_l1:
                cost = cost.mean()
            else:
                cost = cost.sum()
            return cost

        self.f_fista = f_fista
        self.g_fista = g_fista
        self.pl = lambda x, L: softshrink(x, lambd=self.lam / L).data

        if self.solver_type == 'fista_custom':
            self.L = 0.1  # init L
Example #21
 def salsa_forward(self, data, return_loss=False):
     """
     Implements SALSA.
     """
     if return_loss:
         loss = []
     We_y = self.We(data)
     x = We_y
     d = torch.zeros(We_y.size()).to(self.device)
     for it in range(self.n_iter):
         u = F.softshrink(x + d, lambd=self.thresh)
         x = self.S(We_y + self.mu * (u - d))
         d += x - u
         if return_loss:
             loss += [self.lossFcn(x, data)]
     # Fin.
     if return_loss:
         return x, loss
     return x
Example #22
 def test_softshrink(self):
     inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype)
     output = F.softshrink(inp, lambd=0.5)
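For reference, F.softshrink(x, lambd) applies soft-thresholding elementwise: x - lambd where x > lambd, x + lambd where x < -lambd, and 0 otherwise. A quick CPU check (my addition):

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, -0.2, 0.0, 0.3, 2.0])
print(F.softshrink(x, lambd=0.5))  # values shrink by 0.5 toward zero: [-0.5, 0.0, 0.0, 0.0, 1.5]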
Example #23
 def __init__(self, lambd, beta, N_l, N_c):
     super(Line, self).__init__()
     self.weightx = nn.Parameter(torch.arange(-500, 510, 10, dtype=torch.float) / 255, requires_grad=False)
     l = lambd / max(beta, 1e-4)
     self.weighty = nn.Parameter(F.softshrink(self.weightx.view(1,1,-1).repeat(N_l, N_c, 1), l).contiguous())
Example #24
def basis_pursuit_admm(A, b, lambd, M_inv=None, tol=1e-4, max_iters=100,
                       return_stats=False):
    r"""
    Basis Pursuit solver for the :math:`Q_1^\epsilon` problem

    .. math::
        \min_x \frac{1}{2} \left|\left| \boldsymbol{A}\vec{x} - \vec{b}
        \right|\right|_2^2 + \lambda \|x\|_1

    via the alternating direction method of multipliers (ADMM).

    Parameters
    ----------
    A : (N, M) torch.Tensor
        The input weight matrix :math:`\boldsymbol{A}`.
    b : (B, N) torch.Tensor
        The right side of the equation :math:`\boldsymbol{A}\vec{x} = \vec{b}`.
    lambd : float
        :math:`\lambda`, controls the sparsity of :math:`\vec{x}`.
    M_inv : (M, M) torch.Tensor, optional
        Precomputed (transposed) inverse of
        :math:`\boldsymbol{A}^\top \boldsymbol{A} + \boldsymbol{I}`;
        computed internally when omitted.
    tol : float
        The accuracy tolerance of ADMM.
    max_iters : int
        Run for at most `max_iters` iterations.
    return_stats : bool
        If True, also return the mean relative update norm and the number
        of iterations performed.

    Returns
    -------
    torch.Tensor
        (B, M) The solution vector batch :math:`\vec{x}`.
    """
    A_dot_b = b.matmul(A)
    if M_inv is None:
        M = A.t().matmul(A) + torch.eye(A.shape[1], device=A.device)
        M_inv = M.inverse().t()
        del M

    batch_size = b.shape[0]
    v = torch.zeros(batch_size, A.shape[1], device=A.device)
    u = torch.zeros(batch_size, A.shape[1], device=A.device)
    v_prev = v.clone()

    v_solution = v.clone()
    solved = torch.zeros(batch_size, dtype=torch.bool)

    iter_id = 0
    dv_norm = None
    for iter_id in range(max_iters):
        b_eff = A_dot_b + v - u
        x = b_eff.matmul(M_inv)  # M_inv is already transposed
        # x is of shape (<=B, m_atoms)
        v = F.softshrink(x + u, lambd)
        u = u + x - v
        v_norm = v.norm(dim=1)
        if (v_norm == 0).any():
            warnings.warn(f"Lambda ({lambd}) is set too large: "
                          f"the output vector is zero-valued.")
        dv_norm = (v - v_prev).norm(dim=1) / (v_norm + 1e-9)
        solved_batch = dv_norm < tol
        v, u, A_dot_b = _reduce(solved, solved_batch, v_solution, v, u,
                                A_dot_b)
        if v.shape[0] == 0:
            # all solved
            break
        v_prev = v.clone()

    if iter_id != max_iters - 1:
        assert solved.all()
    v_solution[~solved] = v  # dump unsolved iterations

    if return_stats:
        return v_solution, dv_norm.mean(), iter_id

    return v_solution
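Relating the loop above to the ADMM splitting it implements (my summary, using the docstring's notation and M = A^T A + I):

\begin{aligned}
\vec{x}^{(k+1)} &= \boldsymbol{M}^{-1}\big(\boldsymbol{A}^{\top}\vec{b} + \vec{v}^{(k)} - \vec{u}^{(k)}\big),\\
\vec{v}^{(k+1)} &= \operatorname{soft}_{\lambda}\big(\vec{x}^{(k+1)} + \vec{u}^{(k)}\big),\\
\vec{u}^{(k+1)} &= \vec{u}^{(k)} + \vec{x}^{(k+1)} - \vec{v}^{(k+1)},
\end{aligned}

where F.softshrink plays the role of the proximal operator of \lambda\|\cdot\|_1; iteration stops once the relative change in v falls below tol for every sample in the batch.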
Example #25
    def step(self, normalizer=None, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                if normalizer is None:
                    grad = p.grad.data
                else:
                    grad = p.grad.data.div(normalizer)

                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead'
                    )
                amsgrad = group['amsgrad']
                lamda = group['lamda']
                normalize = group['normalize']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad.add_(group['weight_decay'], p.data)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1**state['step']
                bias_correction2 = 1 - beta2**state['step']
                step_size = group['lr'] * math.sqrt(
                    bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)

                if lamda != 0:
                    p.data = softshrink(p.data, lamda)

                if normalize:
                    p.data.div_(norm(p.data))

        return loss
Example #26
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    model.train()
    avg_loss = 0.
    if epoch == args.TSP_epoch and args.TSP:
        x = model.get_trajectory()
        x = x.detach().cpu().numpy()
        if not args.KMEANS:
            x = tsp_solver(x)
        else:
            x = kmeans_tsp(x)
        v, a = get_vel_acc(x)
        writer.add_figure('TSP_Trajectory_Proj_XY', plot_trajectory(x, 0, 1),
                          epoch)
        writer.add_figure('TSP_Trajectory_Proj_YZ', plot_trajectory(x, 1, 2),
                          epoch)
        writer.add_figure('TSP_Trajectory_Proj_XZ', plot_trajectory(x, 0, 2),
                          epoch)
        writer.add_figure('TSP_Trajectory_3D', plot_trajectory(x, d3=True),
                          epoch)

        writer.add_figure('TSP_Acc', plot_acc(a, args.a_max), epoch)
        writer.add_figure('TSP_Vel', plot_acc(v, args.v_max), epoch)
        np.save('trajTSP', x)
        with torch.no_grad():
            model.subsampling.x.data = torch.tensor(x, device='cuda')
        args.a_max *= 2
        args.v_max *= 2

    if args.TSP and epoch > args.TSP_epoch and epoch <= args.TSP_epoch * 2:
        v0 = args.gamma * args.G_max * args.FOV * args.dt
        a0 = args.gamma * args.S_max * args.FOV * args.dt**2 * 1e3
        args.a_max -= a0 / args.TSP_epoch
        args.v_max -= v0 / args.TSP_epoch

    if args.TSP and epoch == args.TSP_epoch * 2:
        args.vel_weight *= 10
        args.acc_weight *= 10
    if args.TSP and epoch == args.TSP_epoch * 2 + 10:
        args.vel_weight *= 10
        args.acc_weight *= 10
    if args.TSP and epoch == args.TSP_epoch * 2 + 20:
        args.vel_weight *= 10
        args.acc_weight *= 10

    start_epoch = start_iter = time.perf_counter()
    for iter, data in enumerate(data_loader):
        input, target, mean, std, norm = data
        input = input.unsqueeze(1).to(args.device)
        target = target.to(args.device)

        output = model(input).squeeze(1)
        x = model.get_trajectory()
        v, a = get_vel_acc(x)

        acc_loss = args.acc_weight * torch.sqrt(
            torch.sum(torch.pow(F.softshrink(a, args.a_max), 2)))
        vel_loss = args.vel_weight * torch.sqrt(
            torch.sum(torch.pow(F.softshrink(v, args.v_max), 2)))
        rec_loss = args.rec_weight * F.l1_loss(output, target)
        if args.TSP and epoch < args.TSP_epoch:
            loss = rec_loss
        else:
            loss = rec_loss + vel_loss + acc_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        #writer.add_scalar('TrainLoss', loss.item(), iter)

        if iter % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'rec_loss: {rec_loss:.4g}, vel_loss: {vel_loss:.4g}, acc_loss: {acc_loss:.4g}, '
            )
        start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch
Example #27
def train_epoch(args, epoch, model, data_loader, optimizer, writer):
    model.train()
    avg_loss = 0.
    if epoch == args.TSP_epoch and args.TSP:
        x = model.module.get_trajectory()
        x = x.detach().cpu().numpy()
        for shot in range(x.shape[0]):
            x[shot, :, :] = tsp_solver(x[shot, :, :])
        v, a = get_vel_acc(x)
        writer.add_figure('TSP_Trajectory', plot_trajectory(x), epoch)
        writer.add_figure('TSP_Acc', plot_acc(a, args.a_max), epoch)
        writer.add_figure('TSP_Vel', plot_acc(v, args.v_max), epoch)
        np.save('trajTSP', x)
        with torch.no_grad():
            model.module.subsampling.x.data = torch.tensor(x, device='cuda')
        args.a_max *= 2
        args.v_max *= 2
        args.vel_weight = 1e-3
        args.acc_weight = 1e-3

    # if epoch == 30:
    #     v0 = args.gamma * args.G_max * args.FOV * args.dt
    #     a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
    #     args.a_max = a0 *1.5
    #     args.v_max = v0 *1.5

    # if args.TSP and epoch > args.TSP_epoch and epoch<=args.TSP_epoch*2:
    #     v0 = args.gamma * args.G_max * args.FOV * args.dt
    #     a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
    #     args.a_max -= a0/args.TSP_epoch
    #     args.v_max -= v0/args.TSP_epoch
    #
    # if args.TSP and epoch==args.TSP_epoch*2:
    #     v0 = args.gamma * args.G_max * args.FOV * args.dt
    #     a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
    #     args.a_max = a0
    #     args.v_max = v0
    #     args.vel_weight *= 10
    #     args.acc_weight *= 10
    # if args.TSP and epoch==args.TSP_epoch*2+10:
    #     args.vel_weight *= 10
    #     args.acc_weight *= 10
    # if args.TSP and epoch==args.TSP_epoch*2+20:
    #     args.vel_weight *= 10
    #     args.acc_weight *= 10
    if args.TSP:
        if epoch < args.TSP_epoch:
            model.module.subsampling.interp_gap = 1
        elif epoch < 10 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 10
            v0 = args.gamma * args.G_max * args.FOV * args.dt
            a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
            args.a_max -= a0/args.TSP_epoch
            args.v_max -= v0/args.TSP_epoch
        elif epoch == 10 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 10
            v0 = args.gamma * args.G_max * args.FOV * args.dt
            a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
            args.a_max -= a0 / args.TSP_epoch
            args.v_max -= v0 / args.TSP_epoch
        elif epoch == 15 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 10
            v0 = args.gamma * args.G_max * args.FOV * args.dt
            a0 = args.gamma * args.S_max * args.FOV * args.dt ** 2 * 1e3
            args.a_max -= a0 / args.TSP_epoch
            args.v_max -= v0 / args.TSP_epoch
        elif epoch == 20 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 10
            args.vel_weight *= 10
            args.acc_weight *= 10
        elif epoch == 23 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 5
            args.vel_weight *= 10
            args.acc_weight *= 10
        elif epoch == 25 + args.TSP_epoch:
            model.module.subsampling.interp_gap = 1
            args.vel_weight *= 10
            args.acc_weight *= 10
    else:
        if epoch < 10:
            model.module.subsampling.interp_gap = 50
        elif epoch == 10:
            model.module.subsampling.interp_gap = 30
        elif epoch == 15:
            model.module.subsampling.interp_gap = 20
        elif epoch == 20:
            model.module.subsampling.interp_gap = 10
        elif epoch == 23:
            model.module.subsampling.interp_gap = 5
        elif epoch == 25:
            model.module.subsampling.interp_gap = 1

    start_epoch = start_iter = time.perf_counter()
    print(f'a_max={args.a_max}, v_max={args.v_max}')
    for iter, data in enumerate(data_loader):
        optimizer.zero_grad()
        input, target, mean, std, norm = data
        input = input.to(args.device)
        target = target.to(args.device)

        output = model(input)
        # output = transforms.complex_abs(output)  # complex to real
        # output = transforms.root_sum_of_squares(output, dim=1)
        output = output.squeeze()

        x = model.module.get_trajectory()
        v, a = get_vel_acc(x)
        acc_loss = torch.sqrt(torch.sum(torch.pow(F.softshrink(a, args.a_max).abs()+1e-8, 2)))
        vel_loss = torch.sqrt(torch.sum(torch.pow(F.softshrink(v, args.v_max).abs()+1e-8, 2)))

        rec_loss = F.l1_loss(output, target)
        if args.TSP and epoch < args.TSP_epoch:
            loss = args.rec_weight * rec_loss
        else:
            loss = args.rec_weight * rec_loss + args.vel_weight * vel_loss + args.acc_weight * acc_loss

        loss.backward()
        optimizer.step()

        avg_loss = 0.99 * avg_loss + 0.01 * loss.item() if iter > 0 else loss.item()
        # writer.add_scalar('TrainLoss', loss.item(), global_step + iter)

        if iter % args.report_interval == 0:
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{iter:4d}/{len(data_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'rec_loss: {rec_loss:.4g}, vel_loss: {vel_loss:.4g}, acc_loss: {acc_loss:.4g}'
            )
        start_iter = time.perf_counter()
    return avg_loss, time.perf_counter() - start_epoch
Example #28
''' tanh/ hardtanh/ softsign '''
x = torch.arange(-5,5,0.1).view(-1,1)
y = F.tanh(x)
pic(x, y, (2,4,3), 'tanh')
y = F.hardtanh(x)
pic(x, y, (2,4,3), 'hardtanh')
y = F.softsign(x)
pic(x, y, (2,4,3), 'softsign')

''' hardshrink/ tanhshrink/ softshrink '''
x = torch.arange(-3,3,0.1).view(-1,1)
y = F.hardshrink(x)
pic(x, y, (2,4,4), 'hardshrink')
y = F.tanhshrink(x)
pic(x, y, (2,4,4), 'tanhshrink')
y = F.softshrink(x)
pic(x, y, (2,4,4), 'softshrink')

''' sigmoid '''
x = torch.arange(-5,5,0.1).view(-1,1)
y = F.sigmoid(x)
pic(x, y, (2,4,5), 'sigmoid')

''' relu6 '''
x = torch.arange(-5,10,0.1).view(-1,1)
y = F.relu6(x)
pic(x, y, (2,4,6), 'relu6')

''' logsigmoid '''
x = torch.arange(-3,3,0.1).view(-1,1)
y = F.logsigmoid(x)
Example #29
def coord_descent_mod(x, W, z0=None, alpha=1.0, max_iter=1000, tol=1e-4):
    """Modified variant of the CD algorithm

    Based on `enet_coordinate_descent` from sklearn.linear_model._cd_fast

    This version is much slower, but it produces more reliable results
    as compared to the above.

    x : Tensor of shape [n_samples, n_features]
    W : Tensor of shape [n_features, n_components]
    z : Tensor of shape [n_samples, n_components]
    """
    n_features, n_components = W.shape
    n_samples = x.shape[0]
    assert x.shape[1] == n_features
    if z0 is None:
        z = x.new_zeros(n_samples, n_components)  # [N,K]
    else:
        assert z0.shape == (n_samples, n_components)
        z = z0

    gap = z.new_full((n_samples,), tol + 1.)
    converged = z.new_zeros(n_samples, dtype=torch.bool)
    d_w_tol = tol
    tol = tol * x.pow(2).sum(1)  # [N,]

    # compute squared norms of the columns of X
    norm_cols_X = W.pow(2).sum(0)  # [K,]

    # function to check convergence state (per sample)
    def _check_convergence(z_, x_, R_, tol_):
        XtA = torch.mm(R_, W)  # [N,K]
        dual_norm_XtA = XtA.abs().max(1)[0]  # [N,]
        R_norm2 = R_.pow(2).sum(1)  # [N,]

        small_norm = dual_norm_XtA <= alpha
        const = (alpha / dual_norm_XtA).masked_fill(small_norm, 1.)
        gap = torch.where(small_norm, R_norm2, 0.5 * R_norm2 * (1 + const.pow(2)))

        gap = gap + alpha * z_.abs().sum(1) - const * (R_ * x_).sum(1)
        converged = gap < tol_

        return converged, gap

    # initialize residual
    R = x - torch.matmul(z, W.T) # [N,D]

    for n_iter in range(max_iter):
        if converged.all():
            break
        active_ix, = torch.where(~converged)
        z_max = z.new_zeros(len(active_ix))
        d_z_max = z.new_zeros(len(active_ix))
        for i in range(n_components):  # Loop over components
            if norm_cols_X[i] == 0:
                continue

            atom_i = W[:,i].contiguous()

            z_i = z[active_ix, i].clone()
            nonzero = z_i != 0
            R[active_ix[nonzero]] += torch.outer(z_i[nonzero], atom_i)

            z[active_ix, i] = F.softshrink(R[active_ix].matmul(atom_i), alpha)
            z[active_ix, i] /= norm_cols_X[i]

            z_new_i = z[active_ix, i]
            nonzero = z_new_i != 0
            R[active_ix[nonzero]] -= torch.outer(z_new_i[nonzero], atom_i)

            # update the maximum absolute coefficient update
            d_z_max = torch.maximum(d_z_max, (z_new_i - z_i).abs())
            z_max = torch.maximum(z_max, z_new_i.abs())

        ### check convergence ###
        check = (z_max == 0) | (d_z_max / z_max < d_w_tol) | (n_iter == max_iter-1)
        if not check.any():
            continue
        check_ix = active_ix[check]
        converged[check_ix], gap[check_ix] = \
            _check_convergence(z[check_ix], x[check_ix], R[check_ix], tol[check_ix])

    return z, gap
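A minimal call sketch for coord_descent_mod (my addition); it assumes torch and torch.nn.functional as F are imported as above:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(64, 256)                 # [n_features, n_components]
W = W / W.norm(dim=0, keepdim=True)
x = torch.randn(16, 64)                  # [n_samples, n_features]
z, gap = coord_descent_mod(x, W, alpha=0.5, max_iter=100, tol=1e-4)
print(z.shape, gap.shape)                # [16, 256] codes and a per-sample duality gap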
Example #30
 def softshrink(self, x, lambd=0.5):
     return F.softshrink(x, lambd)