class LearningKernelwithRandomFeature(BaseEstimator, TransformerMixin):
    """Learns importance weights of random features by maximizing the 
    kernel alignment with a divergence constraint.

    Parameters
    ----------
    transformer : sklearn transformer object (default=None)
        A random feature map object.
        If None, RBFSampler is used.

    divergence : str (default='chi2')
        Which f-divergence to use.
        
        - 'chi2': (p/q)^2 - 1

        - 'kl': (p/q) * log(p/q)

        - 'tv': |p/q - 1| / 2

        - 'squared': 0.5 * (p-q)^2

        - 'reverse_kl': -log(p/q)

    max_iter : int (default=100)
        Maximum number of iterations.

    rho : double (default=1.)
        An upper bound on the divergence.
        If rho=0, the importance weights will be 1/sqrt(n_components).

    alpha : double (default=None)
        A strength hyperparameter for the divergence.
        If not None, optimize the regularized objective (+ alpha * divergence).
        If None, optimize the constrained objective (divergence < rho).

    tol : double (default=1e-8)
        Tolerance of stopping criterion.

    scaling : bool (default=True)
        Whether to scale the kernel alignment term in the objective or not.
        If True, the kernel alignment term is scaled to [0, 1] by dividing it
        by the maximum kernel alignment value.

    warm_start : bool (default=False)
        Whether to activate warm start or not.

    max_iter_admm : int (default=10000)
        Maximum number of iterations in the ADMM optimization.
        This is used if divergence='tv' or 'reverse_kl'.
    
    mu : double (default=10)
        A parameter for the ADMM optimization.
        A larger mu leads to more frequent updates of the penalty parameter.
        This is used if divergence='tv' or 'reverse_kl'.

    tau_incr : double (default=2)
        A parameter for the ADMM optimization.
        The penalty parameter is updated by multiplying it by tau_incr.
        This is used if divergence='tv' or 'reverse_kl'.

    tau_decr : double (default=2)
        A parameter for the ADMM optimization.
        The penalty parameter is updated by dividing it by tau_decr.
        This is used if divergence='tv' or 'reverse_kl'.
        
    eps_abs : double (default=1e-4)
        A parameter for the ADMM optimization.
        It is used in the stopping criterion.
        This is used if divergence='tv' or 'reverse_kl'.
     
    eps_rel : double (default=1e-4)
        A parameter for the ADMM optimization.
        It is used in the stopping criterion.
        This is used if divergence='tv' or 'reverse_kl'.

    verbose : bool (default=True)
        Verbose mode or not.

    Attributes
    ----------
    importance_weights_ : array, shape (n_components, )
        The learned importance weights.
    
    divergence_ : Divergence instance
        The divergence instance for optimization.
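
    Notes
    -----
    transform rescales the output of transformer so that the uniform
    1/n_components weighting of the random features is replaced by the learned
    importance weights: the inner product between two transformed samples is
    sum_i importance_weights_[i] * phi_i(x) * phi_i(z), where phi_i denotes
    the i-th (unnormalized) random feature produced by transformer.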

    References
    ----------
    [1] Learning Kernels with Random Features.
    Aman Sinha and John Duchi.
    In NIPS 2016.
    (https://papers.nips.cc/paper/6180-learning-kernels-with-random-features.pdf)
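
    Examples
    --------
    A minimal usage sketch with scikit-learn's RBFSampler as the random
    feature map; the exact learned weights depend on the transformer and data.

    >>> import numpy as np
    >>> from sklearn.kernel_approximation import RBFSampler
    >>> rng = np.random.RandomState(0)
    >>> X = rng.randn(20, 5)
    >>> y = np.sign(X[:, 0])  # binary labels in {-1, 1}
    >>> transformer = RBFSampler(n_components=100, random_state=0)
    >>> lkrf = LearningKernelwithRandomFeature(transformer, verbose=False)
    >>> X_new = lkrf.fit_transform(X, y)
    >>> X_new.shape
    (20, 100)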

    """
    def __init__(self,
                 transformer=None,
                 divergence='chi2',
                 rho=1.,
                 alpha=None,
                 max_iter=100,
                 tol=1e-8,
                 scaling=True,
                 warm_start=False,
                 max_iter_admm=10000,
                 mu=10,
                 tau_incr=2,
                 tau_decr=2,
                 eps_abs=1e-4,
                 eps_rel=1e-4,
                 verbose=True):
        self.transformer = transformer
        self.divergence = divergence
        self.max_iter = max_iter
        self.rho = rho
        self.alpha = alpha
        self.tol = tol
        self.scaling = scaling
        self.warm_start = warm_start
        self.max_iter_admm = max_iter_admm
        self.verbose = verbose
        self.mu = mu
        self.tau_incr = tau_incr
        self.tau_decr = tau_decr
        self.eps_abs = eps_abs
        self.eps_rel = eps_rel
        self.is_removed = False
        self.n_components = None

    def fit(self, X, y):
        """Fitting the importance weights of each random basis. 

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Training data, where n_samples is the number of samples
            and n_features is the number of features.

        y : array, shape (n_samples, )
            Labels, where n_samples is the number of samples.
            y must be in {-1, 1}^n_samples.

        Returns
        -------
        self : object
            Returns the transformer.
        """
        X, y = check_X_y(X, y)
        y_kind = np.unique(y)
        if not np.array_equal(y_kind, [-1, 1]):
            raise ValueError("Each element in y must be -1 or 1.")

        if not (self.warm_start and hasattr(self, "divergence_")):
            self.divergence_ = self._get_divergence()

        if self.transformer is None:
            self.transformer = RBFSampler()

        if not (self.warm_start
                and hasattr(self.transformer, "random_weights_")):
            self.transformer.fit(X)

        if self.is_removed:
            warnings.warn("Some columns of random_weights_ were removed "
                          "and so transformer is refitted now.")
            n_components = self.importance_weights_.shape[0]
            self.transformer.set_params(n_components=n_components)
            self.transformer.fit(X)
            self.is_removed = False

        if self.n_components is None:
            self.n_components = self.transformer.n_components

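        # per-feature alignment between the transformed training data and the
        # labels; the divergence solver learns the importance weights from it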
        v = compute_X_trans_y(self.transformer, X, y)
        if self.scaling:
            v /= np.max(v)  # scale the alignment values to [0, 1]
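        # start from all-zero importance weights unless warm-starting from a
        # previous fit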
        if not (self.warm_start and hasattr(self, "importance_weights_")):
            self.importance_weights_ = np.zeros(v.shape[0])

        self.divergence_.fit(self.importance_weights_, v)
        return self

    def transform(self, X):
        """Apply the approximate feature map to X.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            New data, where n_samples is the number of samples
            and n_features is the number of features.

        Returns
        -------
        X_new : array-like, shape (n_samples, n_components)
        """
        check_is_fitted(self, "importance_weights_")
        X_trans = self.transformer.transform(X)
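        # remove the 1/sqrt(n_components) factor that random feature maps such
        # as RBFSampler apply, so that each feature can be re-weighted below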
        X_trans *= np.sqrt(X_trans.shape[1])
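        # weight each feature by the square root of its learned importance
        # weight; after remove_bases, only the non-zero weights are kept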
        if self.is_removed:
            indices = np.nonzero(self.importance_weights_)[0]
            return X_trans * np.sqrt(self.importance_weights_[indices])
        else:
            return X_trans * np.sqrt(self.importance_weights_)

    def remove_bases(self):
        """Remove the useless random bases according to importance_weights_.
        
        Parameters
        ----------

        Returns
        -------
        self : bool
            Whether to remove bases or not.
        """
        check_is_fitted(self, "importance_weights_")
        remove_indices = np.where(self.importance_weights_ == 0)[0]

        flag = False
        if self.is_removed:
            warnings.warn("Random bases have already been removed.")
        elif len(remove_indices) == 0:
            if self.verbose:
                print("importance_weights has no zero values. "
                      "Bases are not removed.")
        elif hasattr(self.transformer, '_remove_bases'):
            if self.transformer._remove_bases(remove_indices):
                self.is_removed = True
                self.n_components = np.count_nonzero(self.importance_weights_)
                flag = True
            else:
                if self.verbose:
                    print("Bases are not removed.")
        else:
            warnings.warn("transformer does not have _remove_bases method.")
        return flag

    def _get_divergence(self):
        if self.divergence == 'chi2':
            return Chi2(self.rho, self.alpha, self.tol, self.warm_start,
                        self.max_iter, self.verbose)
        elif self.divergence == 'kl':
            return KL(self.rho, self.alpha, self.tol, self.warm_start,
                      self.max_iter, self.verbose)
        elif self.divergence == 'tv':
            return TV(self.rho, self.alpha, self.tol, self.warm_start,
                      self.max_iter, self.verbose, self.max_iter_admm, self.mu,
                      self.tau_incr, self.tau_decr, self.eps_abs, self.eps_rel)
        elif self.divergence == 'chi2_origin':
            return Chi2Origin(self.rho, self.alpha, self.tol, self.warm_start,
                              self.max_iter, self.verbose)
        elif self.divergence == 'squared':
            return Squared(self.rho, self.alpha, self.tol, self.warm_start,
                           self.max_iter, self.verbose)
        elif self.divergence == 'reverse_kl':
            return ReverseKL(self.rho, self.alpha, self.tol, self.warm_start,
                             self.max_iter, self.verbose, self.max_iter_admm,
                             self.mu, self.tau_incr, self.tau_decr,
                             self.eps_abs, self.eps_rel)
        else:
            raise ValueError("f={} is not supported now. Use {'chi2'|'kl'|"
                             "'tv'|'reverse_kl'|'squared'}.".format(
                                 self.divergence))