class DensityEstimator:
    """Fit a density model to ``training_set`` and draw samples from it.

    Supported ``method_name`` values (dispatched in ``__init__``):

    * ``"GMM_Dirichlet"`` -- ``mixture.BayesianGaussianMixture`` with
      ``n_components`` full-covariance components (``n_components`` required).
    * ``"GMM"`` -- ``mixture.GaussianMixture`` with ``n_components``.
    * ``"GMM_1"``, ``"GMM_10"``, ``"GMM_20"``, ``"GMM_100"``, ``"GMM_200"``
      -- ``mixture.GaussianMixture`` with the component count in the name.
    * any name containing ``"aux_vae"`` -- a ``VaeModelWrapper``. Any text
      after the first 8 characters (e.g. ``"GMM"`` in ``"aux_vae_GMM"``) is
      stored as ``second_density_mdl`` and passed to ``fit``.
    * ``"given_zs"`` -- no model. Samples are loaded from a ``.npy`` file in
      ``log_dir`` and returned unchanged by ``get_samples``.
    * ``"KDE"`` (case-insensitive) -- ``KernelDensity`` with a Gaussian
      kernel and bandwidth 0.425.
    """

    # Component counts for the fixed-size GaussianMixture variants.
    _FIXED_GMM_COMPONENTS = {
        "GMM_1": 1,
        "GMM_10": 10,
        "GMM_20": 20,
        "GMM_100": 100,
        "GMM_200": 200,
    }

    # Upper-cased method names whose ``model.sample`` result is indexed with
    # ``[0]`` to get the samples (the VAE methods are matched separately).
    _TUPLE_SAMPLERS = frozenset({
        "GMM_DIRICHLET", "GMM", "GMM_1", "GMM_10", "GMM_20", "GMM_100",
        "GMM_200",
    })

    def __init__(self,
                 training_set,
                 method_name,
                 n_components=None,
                 log_dir=None,
                 second_stage_beta=None):
        """Build the underlying model for ``method_name``.

        Args:
            training_set: data passed to ``model.fit``. Its last dimension
                sizes the VAE input and latent space.
            method_name: one of the methods listed in the class docstring.
            n_components: number of components for ``"GMM"`` and
                ``"GMM_Dirichlet"``.
            log_dir: directory passed to the VAE, searched for ``.npy``
                samples by ``"given_zs"``, and used by ``get_samples`` to
                pickle the model.
            second_stage_beta: forwarded to ``VaeModelWrapper`` as
                ``sec_stg_beta``.

        Raises:
            ValueError: ``"GMM_Dirichlet"`` without ``n_components``.
            FileNotFoundError: ``"given_zs"`` and ``log_dir`` has no ``.npy``.
            NotImplementedError: unknown ``method_name``.
        """
        self.log_dir = log_dir
        self.training_set = training_set
        self.fitting_done = False
        self.method_name = method_name
        self.second_density_mdl = None
        self.skip_fitting_and_sampling = False
        # Stays None for "given_zs", which has nothing to fit.
        self.model = None

        if method_name == "GMM_Dirichlet":
            if n_components is None:
                raise ValueError("GMM_Dirichlet requires n_components.")
            self.model = mixture.BayesianGaussianMixture(
                n_components=n_components,
                covariance_type='full',
                weight_concentration_prior=1.0 / n_components)
        elif method_name == "GMM":
            self.model = self._make_gmm(n_components)
        elif method_name in self._FIXED_GMM_COMPONENTS:
            self.model = self._make_gmm(
                self._FIXED_GMM_COMPONENTS[method_name])
        elif method_name.find("aux_vae") >= 0:
            # Anything past "aux_vae_" (8 chars) names a second-stage model.
            have_2nd_density_est = False
            if method_name[8:] != "":
                self.second_density_mdl = method_name[8:]
                have_2nd_density_est = True
            self.model = VaeModelWrapper(
                input_shape=(training_set.shape[-1], ),
                latent_space_dim=training_set.shape[-1],
                have_2nd_density_est=have_2nd_density_est,
                log_dir=self.log_dir,
                sec_stg_beta=second_stage_beta)
        elif method_name == "given_zs":
            self.z_smps = self._load_given_zs(log_dir)
            self.skip_fitting_and_sampling = True
        elif method_name.upper() == "KDE":
            # Alternative previously tried: kernel='tophat', bandwidth=15.
            self.model = KernelDensity(kernel='gaussian', bandwidth=0.425)
        else:
            raise NotImplementedError("Method specified : " +
                                      str(method_name) +
                                      " doesn't have an implementation yet.")

    @staticmethod
    def _make_gmm(n_components):
        """Build the full-covariance GaussianMixture shared by the GMM methods."""
        return mixture.GaussianMixture(n_components=n_components,
                                       covariance_type='full',
                                       max_iter=2000,
                                       verbose=2,
                                       tol=1e-3)

    @staticmethod
    def _load_given_zs(log_dir):
        """Load the first ``.npy`` file (in sorted name order) from ``log_dir``.

        Raises:
            FileNotFoundError: if ``log_dir`` contains no ``.npy`` file.
        """
        npy_files = sorted(name for name in os.listdir(log_dir)
                           if name.endswith('.npy'))
        if not npy_files:
            raise FileNotFoundError(
                "No .npy file found in log_dir: " + str(log_dir))
        return np.load(os.path.join(log_dir, npy_files[0]))

    def _is_vae(self):
        """Return True when the model is the VAE wrapper (saved via its own API)."""
        return 'vae' in self.method_name

    def fitorload(self, file_name=None):
        """Fit the model on ``training_set``, or restore it from ``file_name``.

        VAE models are restored with their own ``load`` method. Every other
        model is unpickled, mirroring how ``save`` writes it. For
        ``"given_zs"`` this only marks fitting as done.

        Args:
            file_name: path written by ``save``; ``None`` means fit from data.
        """
        if not self.skip_fitting_and_sampling:
            if file_name is None:
                # second_density_mdl is None unless the method is
                # "aux_vae_<name>", so non-VAE models get None as the
                # second positional argument.
                self.model.fit(self.training_set, self.second_density_mdl)
            elif self._is_vae():
                self.model.load(file_name)
            else:
                # NOTE: pickle can execute arbitrary code -- only load trusted
                # files (i.e. ones produced by ``save``).
                with open(file_name, 'rb') as f:
                    self.model = pickle.load(f)

        self.fitting_done = True

    def score(self, X, y=None):
        """Return ``self.model.score(X, y)``.

        Raises:
            NotImplementedError: for VAE-based and ``"given_zs"`` estimators.
        """
        if ("AUX_VAE" in self.method_name.upper()
                or self.skip_fitting_and_sampling):
            raise NotImplementedError(
                "Log likelihood evaluation for VAE is difficult. or skipped")
        return self.model.score(X, y)

    def save(self, file_name):
        """Persist the model to ``file_name`` (no-op for ``"given_zs"``).

        VAE models use their own ``save``; all others are pickled.
        """
        if not self.skip_fitting_and_sampling:
            if self._is_vae():
                self.model.save(file_name)
            else:
                with open(file_name, 'wb') as f:
                    pickle.dump(self.model, f)

    def reconstruct(self, input_batch):
        """Pass ``input_batch`` through the VAE's ``reconstruct``.

        Raises:
            ValueError: if the estimator is not VAE-based.
        """
        if self.method_name.upper().find("AUX_VAE") < 0:
            raise ValueError("Non autoencoder style density estimator: " +
                             self.method_name)
        return self.model.reconstruct(input_batch)

    def get_samples(self, n_samples):
        """Draw ``n_samples`` samples from the model, in shuffled row order.

        Fits the model first if needed. If ``log_dir`` is set, the model is
        also pickled to ``<log_dir>/<method_name>_mdl.pkl``. For
        ``"given_zs"`` the pre-loaded samples are returned as-is, ignoring
        ``n_samples``.
        """
        if self.skip_fitting_and_sampling:
            return self.z_smps

        if not self.fitting_done:
            self.fitorload()
        # Permutation drawn before sampling to keep the original RNG order.
        scrmb_idx = np.random.permutation(n_samples)
        if self.log_dir is not None:
            # NOTE(review): save() avoids pickling VAE models; pickling one
            # here may fail -- confirm this is intended for aux_vae methods.
            pickle_path = os.path.join(self.log_dir,
                                       self.method_name + '_mdl.pkl')
            with open(pickle_path, 'wb') as f:
                pickle.dump(self.model, f)

        upper_name = self.method_name.upper()
        drawn = self.model.sample(n_samples)
        if upper_name in self._TUPLE_SAMPLERS or "AUX_VAE" in upper_name:
            # These samplers return a tuple whose first element is the data.
            drawn = drawn[0]
        # Bug fix: the old code indexed the ``None`` returned by
        # np.random.shuffle here, so non-tuple samplers (KDE) always crashed.
        return drawn[scrmb_idx, :]