예제 #1
0
 def visualize_batch_np(self,
                        codes_all_arr,
                        truncation=None,
                        mean_latent=None,
                        B=15):
     """Render a 2-D array of latent codes to images, ``B`` rows at a time.

     Args:
         codes_all_arr: numpy array whose rows are latent codes (sliced with
             ``[start:end, :]``); each slice is sent to CUDA as float32.
         truncation: forwarded to ``self.visualize``; ``self.truncation``
             when None.
         mean_latent: forwarded to ``self.visualize``; ``self.mean_latent``
             when None.
         B: rows per forward pass. When ``self.StyleGAN.size == 1024`` it is
             divided by 4 (rounded, but never below 1).

     Returns:
         CPU tensor with the images of all rows concatenated along dim 0, or
         None if ``codes_all_arr`` has zero rows.

     Raises:
         ValueError: if ``B`` < 1 (the batching loop could never advance).
     """
     if truncation is None:
         truncation = self.truncation
     if mean_latent is None:
         mean_latent = self.mean_latent
     if B < 1:
         raise ValueError("B must be >= 1, got %r" % (B,))
     if self.StyleGAN.size == 1024:
         # round() rounds half to even, so B=1 and B=2 both used to become 0
         # here and the loop below spun forever; keep at least one row/batch.
         B = max(1, round(B / 4))
     imgn = codes_all_arr.shape[0]
     batches = []
     csr = 0
     while csr < imgn:
         csr_end = min(csr + B, imgn)
         with torch.no_grad():
             codes = torch.from_numpy(codes_all_arr[csr:csr_end, :]).float().cuda()
             batches.append(
                 self.visualize(
                     codes,
                     truncation=truncation,
                     mean_latent=mean_latent,
                 ).cpu())
         csr = csr_end
         clear_output(wait=True)
         progress_bar(csr_end, imgn,
                      "ploting row of page: %d of %d" % (csr_end, imgn))
     # Concatenate once instead of re-copying the growing result every batch.
     return torch.cat(batches, dim=0) if batches else None
def get_full_hessian(loss_grad, model):
    """Build the dense Hessian from first-order gradients of ``model``'s parameters.

    Adapted from
    https://discuss.pytorch.org/t/compute-the-hessian-matrix-of-a-network/15270/3

    Args:
        loss_grad: iterable of first-order gradient tensors, one per entry of
            ``model.parameters()`` and in the same order. They must still be
            differentiable (presumably produced with ``create_graph=True`` --
            verify at the call site).
        model: module whose ``parameters()`` the gradients are taken w.r.t.

    Returns:
        numpy array of shape (P, P), P = total number of gradient elements;
        row ``i`` is the gradient of the i-th flattened gradient entry w.r.t.
        all parameters.

    Raises:
        ValueError: if ``loss_grad`` is empty.
    """
    loss_grad = list(loss_grad)
    if not loss_grad:
        # Previously this fell through to an UnboundLocalError on g_vector.
        raise ValueError("loss_grad is empty; nothing to differentiate")
    n_grads = len(loss_grad)
    flat_pieces = []
    for i, g in enumerate(loss_grad):
        progress_bar(
            i,
            n_grads,
            "flattening to full gradient: %d of %d" % (i, n_grads),
        )
        flat_pieces.append(g.contiguous().view(-1))
    # One concatenation instead of growing the vector inside the loop.
    g_vector = torch.cat(flat_pieces)
    params = list(model.parameters())  # build once, reused for every row
    hessian_size = g_vector.size(0)
    hessian = torch.zeros(hessian_size, hessian_size)
    for idx in range(hessian_size):
        progress_bar(idx, hessian_size,
                     "full hessian columns: %d of %d" % (idx, hessian_size))
        # retain_graph: g_vector's graph is differentiated once per row.
        # create_graph is not needed -- third-order terms are never used, and
        # recording a graph per row only costs memory.
        grad2rd = torch.autograd.grad(g_vector[idx], params, retain_graph=True)
        hessian[idx] = torch.cat([g.contiguous().view(-1) for g in grad2rd])
    return hessian.cpu().numpy()
예제 #3
0
 def visualize_batch_np(self, codes_all_arr, truncation=0.7, B=5):
     """Render a 2-D array of latent codes to images, ``B`` rows per forward pass.

     Args:
         codes_all_arr: numpy array whose rows are latent codes (sliced with
             ``[start:end, :]``); each slice is sent to CUDA as float32.
         truncation: forwarded to ``self.visualize``.
         B: rows per forward pass.

     Returns:
         CPU tensor with the images of all rows concatenated along dim 0, or
         None if ``codes_all_arr`` has zero rows.

     Raises:
         ValueError: if ``B`` < 1 (the batching loop could never advance).
     """
     if B < 1:
         raise ValueError("B must be >= 1, got %r" % (B,))
     imgn = codes_all_arr.shape[0]
     batches = []
     csr = 0
     with torch.no_grad():
         while csr < imgn:
             csr_end = min(csr + B, imgn)
             codes = torch.from_numpy(codes_all_arr[csr:csr_end, :]).float().cuda()
             batches.append(self.visualize(codes, truncation=truncation).cpu())
             csr = csr_end
             clear_output(wait=True)
             progress_bar(csr_end, imgn, "ploting row of page: %d of %d" % (csr_end, imgn))
     # Concatenate once rather than re-copying the growing result per batch.
     return torch.cat(batches, dim=0) if batches else None
def power_iteration(
    operator: Operator,
    steps: int = 20,
    error_threshold: float = 1e-4,
    momentum: float = 0.0,
    use_gpu: bool = True,
    fp16: bool = False,
    init_vec: torch.Tensor = None,
) -> Tuple[float, torch.Tensor]:
    """
    Compute dominant eigenvalue/eigenvector of a matrix by power iteration.

    operator: linear Operator giving us matrix-vector product access
        (``operator.size`` is the input dimension, ``operator.apply(v)`` the
        product)
    steps: maximum number of update steps to take; must be >= 1
    error_threshold: stop early once the relative change of the eigenvalue
        estimate between consecutive steps falls below this value
    momentum: the normalized current iterate, scaled by this factor, is
        subtracted from each matrix-vector product
    use_gpu: move the iterate to CUDA before iterating
    fp16: route vectors through ``utils.maybe_fp16``
    init_vec: starting vector; uniform random (``torch.rand``) when None
    returns: (principal eigenvalue, principal eigenvector) pair; if an update
        lands in the operator's nullspace, (0.0, that zero vector)
    raises: ValueError if ``steps`` < 1 (no estimate would ever be produced)
    """
    if steps < 1:
        raise ValueError("steps must be >= 1, got %r" % (steps,))
    vector_size = operator.size  # input dimension of operator
    vec = torch.rand(vector_size) if init_vec is None else init_vec
    vec = utils.maybe_fp16(vec, fp16)
    if use_gpu:
        vec = vec.cuda()
    # Start from a (near-)unit vector: vec.dot(A vec) below is an eigenvalue
    # estimate only for unit vec; otherwise the first estimate (the returned
    # value when steps == 1) is scaled by ||vec||^2. The epsilon keeps an
    # all-zero start at zero, which the nullspace check still handles.
    vec = vec / (torch.norm(vec) + 1e-6)

    prev_lambda = 0.0
    for i in range(steps):
        prev_vec = vec / (torch.norm(vec) + 1e-6)
        new_vec = utils.maybe_fp16(operator.apply(vec), fp16) - momentum * prev_vec
        # need to handle case where we end up in the nullspace of the operator.
        # in this case, we are done.
        if torch.norm(new_vec).item() == 0.0:
            return 0.0, new_vec
        lambda_estimate = vec.dot(new_vec).item()
        diff = lambda_estimate - prev_lambda
        vec = new_vec.detach() / torch.norm(new_vec)
        # Relative change of the estimate; a zero estimate (low-rank operator)
        # cannot be normalized, so treat it as not converged.
        error = 1.0 if lambda_estimate == 0.0 else np.abs(diff / lambda_estimate)
        utils.progress_bar(i, steps, "power iter error: %.4f" % error)
        if error < error_threshold:
            break
        prev_lambda = lambda_estimate
    return lambda_estimate, vec
예제 #5
0
def get_full_hessian(loss, param):
    """Dense Hessian of a scalar ``loss`` w.r.t. a single tensor ``param``.

    Adapted from
    https://discuss.pytorch.org/t/compute-the-hessian-matrix-of-a-network/15270/3
    and the hessian_eigenthings repo; the (loss, param) API follows
    hessian.hessian.

    Args:
        loss: scalar tensor to differentiate twice.
        param: tensor that ``loss`` depends on (must require grad).

    Returns:
        float32 numpy array of shape (param.numel(), param.numel()); row
        ``i`` is the gradient of the i-th flattened entry of
        d(loss)/d(param) w.r.t. ``param``.
    """
    hessian_size = param.numel()
    hessian = torch.zeros(hessian_size, hessian_size)
    # create_graph=True so the first-order gradient can be differentiated again.
    loss_grad = torch.autograd.grad(loss,
                                    param,
                                    create_graph=True,
                                    retain_graph=True)[0].reshape(-1)
    if not loss_grad.requires_grad:
        # The gradient carries no graph back to ``param`` (e.g. loss is linear
        # in it), so every second derivative is zero; differentiating it again
        # would raise "element 0 of tensors does not require grad".
        return hessian.numpy()
    for idx in range(hessian_size):
        clear_output(wait=True)
        progress_bar(idx, hessian_size,
                     "full hessian columns: %d of %d" % (idx, hessian_size))
        # retain_graph: the same first-order graph is reused for every row.
        grad2rd = torch.autograd.grad(loss_grad[idx],
                                      param,
                                      create_graph=False,
                                      retain_graph=True)
        hessian[idx] = grad2rd[0].reshape(-1)
    return hessian.cpu().numpy()
예제 #6
0
    plt.ylabel("eigenvalue (log)")
    plt.legend()
    plt.suptitle("Hessian Spectrum Full Space")
    plt.tight_layout(pad=0.6)
    plt.savefig(join(savedir, "Hessian_sep_cls%d.jpg"%class_id))
    # plt.show()
    print("Spent %.1f sec from start" % (time() - T00))

    #%% Interpolation in the full space
    # One grid row per eigenvector (columns taken from the end, -eigi-1), with
    # n_interp_steps images per row spanning (-2.5, 2.5) via LExpMap.
    # NOTE(review): presumably eigvects columns are sorted by ascending
    # eigenvalue so the last columns are the top directions -- confirm.
    n_eig_rows = 50  # previously hard-coded separately in range() and progress_bar
    n_interp_steps = 11  # must agree with LExpMap steps, noise repeat and grid nrow
    img_rows = []
    for eigi in range(n_eig_rows):  # eigvects.shape[1]
        interp_codes = LExpMap(ref_vect.cpu().numpy(), eigvects[:, -eigi-1], n_interp_steps, (-2.5, 2.5))
        with torch.no_grad():  # rendering only; no gradients needed
            img_rows.append(G.visualize(torch.from_numpy(interp_codes).float().cuda()).cpu())
        clear_output(wait=True)
        # The reported total used to be a stale 256 while only n_eig_rows rows are drawn.
        progress_bar(eigi, n_eig_rows, "ploting row of page: %d of %d" % (eigi, n_eig_rows))
    img_all = torch.cat(img_rows, dim=0)

    imggrid = make_grid(img_all, nrow=n_interp_steps)
    PILimg = ToPILImage()(imggrid)#.show()
    PILimg.save(join(savedir, "eigvect_full_cls%d.jpg"%class_id))
    #% Interpolation in the class space
    img_rows = []
    for eigi in range(n_eig_rows):  # eigvects_clas.shape[1]
        interp_class = LExpMap(classvec.cpu().numpy(), eigvects_clas[:, -eigi-1], n_interp_steps, (-2.5, 2.5))
        # Noise part repeated n_interp_steps times along axis 0, class part interpolated.
        interp_codes = np.hstack((noisevec.cpu().numpy().repeat(n_interp_steps, axis=0), interp_class, ))
        with torch.no_grad():  # rendering only; no gradients needed
            img_rows.append(G.visualize(torch.from_numpy(interp_codes).float().cuda()).cpu())
        clear_output(wait=True)
        # The reported total used to be a stale 128 while only n_eig_rows rows are drawn.
        progress_bar(eigi, n_eig_rows, "ploting row of page: %d of %d" % (eigi, n_eig_rows))
    img_all = torch.cat(img_rows, dim=0)

    imggrid = make_grid(img_all, nrow=n_interp_steps)