Example #1
    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """
        Compute the top_n eigenvalues using the power iteration method.
        maxIter: maximum number of iterations used to compute each eigenvalue
        tol: relative tolerance between two consecutive eigenvalue estimates
        top_n: number of top eigenvalues to compute
        """

        assert top_n >= 1

        device = self.device

        eigenvalues = []
        eigenvectors = []

        computed_dim = 0

        while computed_dim < top_n:
            eigenvalue = None
            v = [torch.randn(p.size()).to(device)
                 for p in self.params]  # generate a random starting vector
            v = normalization(v)  # normalize the vector

            for i in range(maxIter):
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                if self.full_dataset:
                    tmp_eigenvalue, Hv = self.dataloader_hv_product(v)
                else:
                    Hv = hessian_vector_product(self.gradsH, self.params, v)
                    tmp_eigenvalue = group_product(Hv, v).cpu().item()

                v = normalization(Hv)

                if eigenvalue is None:
                    eigenvalue = tmp_eigenvalue
                else:
                    relative_change = abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6)
                    if relative_change < tol:
                        break
                    eigenvalue = tmp_eigenvalue
            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)
            computed_dim += 1

        return eigenvalues, eigenvectors
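These snippets read like the hessian class from the PyHessian library. Assuming that provenance, a minimal usage sketch of the power-iteration method; the pyhessian import, constructor signature, and cuda flag are assumptions, not something the snippet itself defines:

import torch
import torch.nn as nn
from pyhessian import hessian  # assumed provenance of the snippets above

# Small synthetic model and batch so the sketch is self-contained.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(16, 10)
targets = torch.randint(0, 2, (16,))

# The constructor is assumed to populate self.params, self.gradsH,
# self.device and self.full_dataset used by eigenvalues() above.
h = hessian(model, criterion, data=(inputs, targets), cuda=False)
top_eigenvalues, top_eigenvectors = h.eigenvalues(maxIter=100, tol=1e-3, top_n=2)
print("Top-2 Hessian eigenvalues:", top_eigenvalues)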
Example #2
    def density(self, iter=100, n_v=1):
        """
        Compute an estimate of the eigenvalue density using stochastic
        Lanczos quadrature (SLQ).
        iter: number of Lanczos iterations per SLQ run
        n_v: number of SLQ runs (random probe vectors)
        """

        device = self.device
        eigen_list_full = []
        weight_list_full = []

        for k in range(n_v):
            # generate Rademacher (+1/-1) random probe vectors
            v = [
                torch.randint_like(p, high=2, device=device)
                for p in self.params
            ]
            for v_i in v:
                v_i[v_i == 0] = -1
            v = normalization(v)

            # standard Lanczos algorithm initialization
            v_list = [v]
            w_list = []
            alpha_list = []
            beta_list = []
            ############### Lanczos
            for i in range(iter):
                self.model.zero_grad()
                w_prime = [torch.zeros(p.size()).to(device) for p in self.params]
                if i == 0:
                    if self.full_dataset:
                        _, w_prime = self.dataloader_hv_product(v)
                    else:
                        w_prime = hessian_vector_product(
                            self.gradsH, self.params, v)
                    alpha = group_product(w_prime, v)
                    alpha_list.append(alpha.cpu().item())
                    w = group_add(w_prime, v, alpha=-alpha)
                    w_list.append(w)
                else:
                    beta = torch.sqrt(group_product(w, w))
                    beta_list.append(beta.cpu().item())
                    if beta_list[-1] != 0.:
                        # re-orthonormalize against the previous Lanczos vectors
                        v = orthnormal(w, v_list)
                        v_list.append(v)
                    else:
                        # beta vanished: restart with a fresh random vector
                        w = [torch.randn(p.size()).to(device) for p in self.params]
                        v = orthnormal(w, v_list)
                        v_list.append(v)
                    if self.full_dataset:
                        _, w_prime = self.dataloader_hv_product(v)
                    else:
                        w_prime = hessian_vector_product(
                            self.gradsH, self.params, v)
                    alpha = group_product(w_prime, v)
                    alpha_list.append(alpha.cpu().item())
                    w_tmp = group_add(w_prime, v, alpha=-alpha)
                    w = group_add(w_tmp, v_list[-2], alpha=-beta)

            T = torch.zeros(iter, iter).to(device)
            for i in range(len(alpha_list)):
                T[i, i] = alpha_list[i]
                if i < len(alpha_list) - 1:
                    T[i + 1, i] = beta_list[i]
                    T[i, i + 1] = beta_list[i]
            # torch.eig was removed from recent PyTorch; T is symmetric
            # tridiagonal, so use the symmetric eigensolver instead.
            eigenvalues_T, eigenvectors_T = torch.linalg.eigh(T)

            eigen_list = eigenvalues_T
            weight_list = eigenvectors_T[0, :]**2
            eigen_list_full.append(list(eigen_list.cpu().numpy()))
            weight_list_full.append(list(weight_list.cpu().numpy()))

        return eigen_list_full, weight_list_full
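Each SLQ run returns discrete (Ritz value, weight) pairs rather than a curve; a common post-processing step smooths them into a continuous spectral density with a Gaussian kernel. A minimal sketch of that step, where the helper name density_generate and the default grid/kernel parameters are assumptions:

import numpy as np

def density_generate(eigen_list_full, weight_list_full,
                     num_bins=10000, sigma_squared=1e-5, overhead=0.01):
    # Stack the n_v SLQ runs: shape (n_v, iter).
    eigen = np.array(eigen_list_full)
    weight = np.array(weight_list_full)

    # Common evaluation grid spanning the observed spectrum.
    lambda_max = np.max(eigen) + overhead
    lambda_min = np.min(eigen) - overhead
    grid = np.linspace(lambda_min, lambda_max, num=num_bins)
    sigma = sigma_squared * max(1.0, lambda_max - lambda_min)

    density_each = np.zeros((eigen.shape[0], num_bins))
    for i in range(eigen.shape[0]):
        for j, t in enumerate(grid):
            # Gaussian bumps centered at the Ritz values, weighted by the
            # squared first components of the tridiagonal eigenvectors.
            bumps = np.exp(-(t - eigen[i, :]) ** 2 / (2.0 * sigma)) \
                / np.sqrt(2 * np.pi * sigma)
            density_each[i, j] = np.sum(bumps * weight[i, :])

    # Average over runs and normalize so the density integrates to one.
    density = np.mean(density_each, axis=0)
    density /= np.sum(density) * (grid[1] - grid[0])
    return density, grid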
Example #3
    def density(self, iter=100, n_v=1, debug=False):
        """
        Compute an estimate of the eigenvalue density using stochastic
        Lanczos quadrature (SLQ).
        iter: number of Lanczos iterations per SLQ run
        n_v: number of SLQ runs (random probe vectors)
        debug: if True, print per-iteration progress
        """

        device = self.device
        eigen_list_full = []
        weight_list_full = []

        # Prepare to record data
        if self.record_data:
            now = datetime.datetime.now()
            timestamp = "_{:02d}{:02d}_{:02d}{:02d}{:02d}".format(
                now.day, now.month, now.hour, now.minute, now.second)
            save_file = self.data_save_dir + "ESD" + timestamp + ".txt"

        start_time = time.time()
        for k in range(n_v):
            # generate Rademacher (+1/-1) random probe vectors
            v = [
                torch.randint_like(p, high=2, device=device)
                for p in self.params
            ]
            for v_i in v:
                v_i[v_i == 0] = -1
            v = normalization(v)

            # standard Lanczos algorithm initialization
            v_list = [v]
            w_list = []
            alpha_list = []
            beta_list = []
            ############### Lanczos
            for i in range(iter):
                if debug:
                    print("Iteration {}".format(i))
                self.model.zero_grad()
                w_prime = [
                    torch.zeros(p.size()).to(device) for p in self.params
                ]
                if i == 0:
                    if self.full_dataset:
                        _, w_prime = self.dataloader_hv_product(v)
                    else:
                        w_prime = hessian_vector_product(
                            self.gradsH, self.params, v)
                    alpha = group_product(w_prime, v)
                    alpha_list.append(alpha.cpu().item())
                    w = group_add(w_prime, v, alpha=-alpha)
                    w_list.append(w)
                else:
                    beta = torch.sqrt(group_product(w, w))
                    beta_list.append(beta.cpu().item())
                    if beta_list[-1] != 0.:
                        # re-orthonormalize against the previous Lanczos vectors
                        v = orthnormal(w, v_list)
                        v_list.append(v)
                    else:
                        # beta vanished: restart with a fresh random vector
                        w = [
                            torch.randn(p.size()).to(device)
                            for p in self.params
                        ]
                        v = orthnormal(w, v_list)
                        v_list.append(v)
                    if self.full_dataset:
                        _, w_prime = self.dataloader_hv_product(v)
                    else:
                        w_prime = hessian_vector_product(
                            self.gradsH, self.params, v)
                    alpha = group_product(w_prime, v)
                    alpha_list.append(alpha.cpu().item())
                    w_tmp = group_add(w_prime, v, alpha=-alpha)
                    w = group_add(w_tmp, v_list[-2], alpha=-beta)

            T = torch.zeros(iter, iter).to(device)
            for i in range(len(alpha_list)):
                T[i, i] = alpha_list[i]
                if i < len(alpha_list) - 1:
                    T[i + 1, i] = beta_list[i]
                    T[i, i + 1] = beta_list[i]
            # torch.eig was removed from recent PyTorch; T is symmetric
            # tridiagonal, so use the symmetric eigensolver instead.
            eigenvalues_T, eigenvectors_T = torch.linalg.eigh(T)

            eigen_list = eigenvalues_T
            weight_list = eigenvectors_T[0, :]**2
            eigen_list_full.append(list(eigen_list.cpu().numpy()))
            weight_list_full.append(list(weight_list.cpu().numpy()))
        # Write data if applicable
        stop_time = time.time()
        if self.record_data:
            with open(save_file, 'w') as f:
                f.write("Total Elapsed Time(s)\n")
                f.write("{}\n".format(stop_time - start_time))
        return eigen_list_full, weight_list_full
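This variant only adds timing and file recording around the same Lanczos core as Example #2. A quick way to sanity-check that core is to run the same three-term recurrence on a small explicit symmetric matrix, where the extreme eigenvalues of the tridiagonal T should closely match those of A. A self-contained sketch in plain PyTorch:

import torch

torch.manual_seed(0)
n, m = 50, 20                        # matrix size, Lanczos steps
A = torch.randn(n, n)
A = (A + A.T) / 2                    # symmetric test matrix

v = torch.randn(n)
v = v / v.norm()
v_prev = torch.zeros(n)
alphas, betas = [], []
beta = 0.0
for _ in range(m):
    w = A @ v - beta * v_prev        # three-term recurrence
    alpha = torch.dot(w, v)
    w = w - alpha * v
    alphas.append(alpha.item())
    beta = w.norm().item()
    if beta < 1e-10:                 # invariant subspace reached
        break
    betas.append(beta)
    v_prev, v = v, w / beta

# Assemble the tridiagonal matrix T from the recurrence coefficients.
T = torch.diag(torch.tensor(alphas))
off = torch.tensor(betas[:len(alphas) - 1])
T = T + torch.diag(off, 1) + torch.diag(off, -1)

ritz = torch.linalg.eigvalsh(T)
exact = torch.linalg.eigvalsh(A)
print("largest Ritz value :", ritz[-1].item(), "vs exact", exact[-1].item())
print("smallest Ritz value:", ritz[0].item(), "vs exact", exact[0].item())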
Example #4
    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1, debug=False):
        """
        Compute the top_n eigenvalues using the power iteration method.
        maxIter: maximum number of iterations used to compute each eigenvalue
        tol: relative tolerance between two consecutive eigenvalue estimates
        top_n: number of top eigenvalues to compute
        debug: if True, print per-iteration progress
        """

        assert top_n >= 1

        device = self.device

        eigenvalues = []
        eigenvectors = []

        computed_dim = 0

        # Prepare to record data
        if self.record_data:
            now = datetime.datetime.now()
            timestamp = "_{:02d}{:02d}_{:02d}{:02d}{:02d}".format(
                now.day, now.month, now.hour, now.minute, now.second)
            save_file = self.data_save_dir + "TopEigen" + timestamp + ".txt"
            total_time_to_compute = []
            iters_to_compute = []

        start_time = time.time()
        while computed_dim < top_n:
            if debug:
                print("Computing eigenvalue #{}".format(computed_dim + 1))
            eigenvalue = None
            v = [torch.randn(p.size()).to(device)
                 for p in self.params]  # generate random vector
            v = normalization(v)  # normalize the vector

            for i in range(maxIter):
                if debug:
                    print("   Iteration {}".format(i))
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()

                if self.full_dataset:
                    tmp_eigenvalue, Hv = self.dataloader_hv_product(v)
                else:
                    Hv = hessian_vector_product(self.gradsH, self.params, v)
                    tmp_eigenvalue = group_product(Hv, v).cpu().item()

                v = normalization(Hv)

                if eigenvalue is None:
                    eigenvalue = tmp_eigenvalue
                else:
                    relative_change = abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6)
                    if relative_change < tol:
                        break
                    eigenvalue = tmp_eigenvalue
            # Record data (these lists only exist when record_data is set)
            if self.record_data:
                total_time_to_compute.append(time.time() - start_time)
                iters_to_compute.append(i)
            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)
            computed_dim += 1
        # Write data if applicable
        if self.record_data:
            with open(save_file, 'w') as f:
                f.write("Eigenvalue\tTotal Elapsed Time(s)\t#Iterations\n")
                for i in range(top_n):
                    f.write("{}\t{}\t{}\n".format(i + 1,
                                                  total_time_to_compute[i],
                                                  iters_to_compute[i]))
        return eigenvalues, eigenvectors
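All four methods above call the same list-of-tensor helpers without defining them. The implementations below are plausible sketches consistent with how the helpers are used in the snippets, not necessarily the originals:

import torch

def group_product(xs, ys):
    # Inner product of two lists of tensors, summed over all blocks.
    return sum(torch.sum(x * y) for x, y in zip(xs, ys))

def group_add(xs, ys, alpha=1.0):
    # Blockwise xs + alpha * ys for lists of tensors.
    return [x + alpha * y for x, y in zip(xs, ys)]

def normalization(v):
    # Scale a list of tensors to unit global norm (epsilon avoids div-by-zero).
    s = torch.sqrt(group_product(v, v)).cpu().item()
    return [v_i / (s + 1e-6) for v_i in v]

def orthnormal(w, v_list):
    # Gram-Schmidt: project out every vector in v_list, then normalize.
    for v in v_list:
        w = group_add(w, v, alpha=-group_product(w, v))
    return normalization(w)

def hessian_vector_product(gradsH, params, v):
    # Hessian-vector product via double backprop (Pearlmutter's trick);
    # gradsH must have been computed with create_graph=True.
    return torch.autograd.grad(gradsH, params, grad_outputs=v,
                               retain_graph=True)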
def train_hessian(args,
                  trainer,
                  task,
                  epoch_itr,
                  sample_iter=1,
                  maxIter=500,
                  tol=1e-4,
                  top_n=1,
                  ignore_grad=False):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    max_iters = 10
    samples_hessian = []
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if i > max_iters:
            break
        samples = [trainer._prepare_sample(sample) for sample in samples]
        samples_hessian.extend(samples)

    eigenvalues = []
    eigenvectors = []
    computed_dim = 0

    params, gradsH = get_params_grad(trainer.model)
    while computed_dim < top_n:
        eigenvalue = None
        v = [torch.randn(p.size()).cuda() for p in params]
        v = normalization(v)

        for i in range(maxIter):
            trainer.model.zero_grad()
            v = orthnormal(v, eigenvectors)
            loss, sample_size, logging_output, gradsH, tmp_eigenvalue, Hv = trainer.task.train_step_hessian(
                samples_hessian,
                trainer.model,
                trainer.criterion,
                trainer.optimizer,
                ignore_grad,
                v=v)
            v = normalization(Hv)
            if eigenvalue is None:
                eigenvalue = tmp_eigenvalue
            else:
                relative_change = abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6)
                if relative_change < tol:
                    break
                eigenvalue = tmp_eigenvalue
        eigenvalues.append(eigenvalue)
        eigenvectors.append(v)
        computed_dim += 1
    return eigenvalues, eigenvectors
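train_hessian additionally relies on get_params_grad to pull out the trainable parameters and their current gradients. A plausible sketch consistent with the call site above (an assumption, not the original):

import torch

def get_params_grad(model):
    # Collect trainable parameters and their current gradients; a zero
    # tensor stands in for parameters that have no gradient yet.
    params, grads = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        params.append(param)
        grads.append(torch.zeros_like(param) if param.grad is None
                     else param.grad + 0.)  # "+ 0." makes a copy
    return params, grads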