Example 1
    def __init__(self, config: ConfigDict, train_samples: np.ndarray, recorder: Recorder,
                 val_samples: np.ndarray = None, seed: int = 0):
        # suppress tensorflow casting warnings
        logging.getLogger("tensorflow").setLevel(logging.ERROR)

        self.c = config
        self.c.finalize_modifying()

        self._recorder = recorder

        self._context_dim = train_samples[0].shape[-1]
        self._sample_dim = train_samples[1].shape[-1]
        self._train_contexts = tf.constant(train_samples[0], dtype=tf.float32)

        c_net_hidden_dict = {NetworkKeys.NUM_UNITS: self.c.components_net_hidden_layers,
                             NetworkKeys.ACTIVATION: "relu",
                             NetworkKeys.BATCH_NORM: False,
                             NetworkKeys.DROP_PROB: self.c.components_net_drop_prob,
                             NetworkKeys.L2_REG_FACT: self.c.components_net_reg_loss_fact}

        g_net_hidden_dict = {NetworkKeys.NUM_UNITS: self.c.gating_net_hidden_layers,
                             NetworkKeys.ACTIVATION: "relu",
                             NetworkKeys.BATCH_NORM: False,
                             NetworkKeys.DROP_PROB: self.c.gating_net_drop_prob,
                             NetworkKeys.L2_REG_FACT: self.c.gating_net_reg_loss_fact}

        self._model = GaussianEMM(self._context_dim, self._sample_dim, self.c.num_components,
                                  c_net_hidden_dict, g_net_hidden_dict, seed=seed)
        self._c_opts = [k.optimizers.Adam(self.c.components_learning_rate, 0.5) for _ in self._model.components]
        self._g_opt = k.optimizers.Adam(self.c.gating_learning_rate, 0.5)

        dre_params = {NetworkKeys.NUM_UNITS: self.c.dre_hidden_layers,
                      NetworkKeys.ACTIVATION: k.activations.relu,
                      NetworkKeys.DROP_PROB: self.c.dre_drop_prob,
                      NetworkKeys.L2_REG_FACT: self.c.dre_reg_loss_fact}
        self._dre = DensityRatioEstimator(target_train_samples=train_samples,
                                          hidden_params=dre_params,
                                          early_stopping=self.c.dre_early_stopping, target_val_samples=val_samples,
                                          conditional_model=True)

        self._recorder.initialize_module(rec.INITIAL)
        self._recorder(rec.INITIAL, "Nonlinear Conditional EIM - Reparametrization", config)
        self._recorder.initialize_module(rec.MODEL, self.c.train_epochs)
        self._recorder.initialize_module(rec.WEIGHTS_UPDATE, self.c.train_epochs)
        self._recorder.initialize_module(rec.COMPONENT_UPDATE, self.c.train_epochs, self.c.num_components)
        self._recorder.initialize_module(rec.DRE, self.c.train_epochs)
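
As the indexing above implies, train_samples for this conditional variant is a (contexts, samples) pair; a minimal sketch with assumed, purely illustrative shapes (the full class is shown in Example 3):

contexts = np.random.randn(1000, 2).astype(np.float32)  # (N, context_dim)
targets = np.random.randn(1000, 4).astype(np.float32)   # (N, sample_dim)
train_samples = (contexts, targets)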
Example 2
    def __init__(self, config, train_samples, recorder, val_samples=None, seed=0, add_feat_fn=None):

        self.c = config
        self.c.finalize_modifying()

        # build model
        w, m, c = model_init.gmm_init(self.c.initialization, np.array(train_samples, dtype=np.float32),
                                      self.c.num_components, seed=0)
        self._model = GMM(w, m, c)

        self._components_learners = []
        for i in range(self._model.num_components):
            self._components_learners.append(MoreGaussian(m.shape[-1], 1.0, 0.0, False))
        self._weight_learner = RepsCategorical(1.0, 0.0, False)

        # build density ratio estimator
        dre_params = {NetworkKeys.NUM_UNITS: self.c.dre_hidden_layers,
                      NetworkKeys.ACTIVATION: k.activations.relu,
                      NetworkKeys.DROP_PROB: self.c.dre_drop_prob,
                      NetworkKeys.L2_REG_FACT: self.c.dre_reg_loss_fact}
        if add_feat_fn is not None:
            self._dre = AddFeatDensityRatioEstimator(target_train_samples=train_samples, hidden_params=dre_params,
                                                     early_stopping=self.c.dre_early_stopping,
                                                     target_val_samples=val_samples, additional_feature_fn=add_feat_fn)
            name = "Marginal EIM - Additional Features"
        else:
            self._dre = DensityRatioEstimator(target_train_samples=train_samples, hidden_params=dre_params,
                                              early_stopping=self.c.dre_early_stopping, target_val_samples=val_samples)
            name = "Marginal EIM"

        # build recording
        self._recorder = recorder
        self._recorder.initialize_module(rec.INITIAL)
        self._recorder(rec.INITIAL, name, config)
        self._recorder.initialize_module(rec.MODEL, self.c.train_epochs)
        self._recorder.initialize_module(rec.WEIGHTS_UPDATE, self.c.train_epochs)
        self._recorder.initialize_module(rec.COMPONENT_UPDATE, self.c.train_epochs, self.c.num_components)
        self._recorder.initialize_module(rec.DRE, self.c.train_epochs)
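
The marginal variant, by contrast, takes a single sample array; a minimal construction sketch with assumed shapes (the class name is taken from Examples 5 and 6, and the recorder is left as a placeholder):

train_samples = np.random.randn(2000, 4).astype(np.float32)  # (N, sample_dim)
recorder = ...  # a Recorder instance; construction depends on the surrounding project
eim = MarginalMixtureEIM(MarginalMixtureEIM.get_default_config(), train_samples, recorder)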
Example 3
class ConditionalMixtureEIM:
    @staticmethod
    def get_default_config() -> ConfigDict:
        c = ConfigDict(
            num_components=1,
            train_epochs=1000,
            # Component
            components_learning_rate=1e-3,
            components_batch_size=1000,
            components_num_epochs=10,
            components_net_reg_loss_fact=0.,
            components_net_drop_prob=0.0,
            components_net_hidden_layers=[50, 50],
            # Gating
            gating_learning_rate=1e-3,
            gating_batch_size=1000,
            gating_num_epochs=10,
            gating_net_reg_loss_fact=0.,
            gating_net_drop_prob=0.0,
            gating_net_hidden_layers=[50, 50],
            # Density Ratio Estimation
            dre_reg_loss_fact=0.0,  # scaling factor for the L2 regularization of the density ratio estimator
            dre_early_stopping=True,  # use early stopping for density ratio estimator training
            dre_drop_prob=0.0,  # dropout probability for the density ratio estimator (0.0 disables dropout)
            dre_num_iters=1000,  # number of density ratio estimator steps per iteration (max number if early stopping is used)
            dre_batch_size=1000,  # batch size for density ratio estimator training
            dre_hidden_layers=[30, 30]  # width of the density ratio estimator's hidden layers
        )
        c.finalize_adding()
        return c

    def __init__(self,
                 config: ConfigDict,
                 train_samples: np.ndarray,
                 recorder: Recorder,
                 val_samples: np.ndarray = None,
                 seed: int = 0):
        # suppress tensorflow casting warnings
        logging.getLogger("tensorflow").setLevel(logging.ERROR)

        self.c = config
        self.c.finalize_modifying()

        self._recorder = recorder

        self._context_dim = train_samples[0].shape[-1]
        self._sample_dim = train_samples[1].shape[-1]
        self._train_contexts = tf.constant(train_samples[0], dtype=tf.float32)

        c_net_hidden_dict = {
            NetworkKeys.NUM_UNITS: self.c.components_net_hidden_layers,
            NetworkKeys.ACTIVATION: "relu",
            NetworkKeys.BATCH_NORM: False,
            NetworkKeys.DROP_PROB: self.c.components_net_drop_prob,
            NetworkKeys.L2_REG_FACT: self.c.components_net_reg_loss_fact
        }

        g_net_hidden_dict = {
            NetworkKeys.NUM_UNITS: self.c.gating_net_hidden_layers,
            NetworkKeys.ACTIVATION: "relu",
            NetworkKeys.BATCH_NORM: False,
            NetworkKeys.DROP_PROB: self.c.gating_net_drop_prob,
            NetworkKeys.L2_REG_FACT: self.c.gating_net_reg_loss_fact
        }

        self._model = GaussianEMM(self._context_dim,
                                  self._sample_dim,
                                  self.c.num_components,
                                  c_net_hidden_dict,
                                  g_net_hidden_dict,
                                  seed=seed)
        self._c_opts = [
            k.optimizers.Adam(self.c.components_learning_rate, 0.5)
            for _ in self._model.components
        ]
        self._g_opt = k.optimizers.Adam(self.c.gating_learning_rate, 0.5)

        dre_params = {
            NetworkKeys.NUM_UNITS: self.c.dre_hidden_layers,
            NetworkKeys.ACTIVATION: k.activations.relu,
            NetworkKeys.DROP_PROB: self.c.dre_drop_prob,
            NetworkKeys.L2_REG_FACT: self.c.dre_reg_loss_fact
        }
        self._dre = DensityRatioEstimator(
            target_train_samples=train_samples,
            hidden_params=dre_params,
            early_stopping=self.c.dre_early_stopping,
            target_val_samples=val_samples,
            conditional_model=True)

        self._recorder.initialize_module(rec.INITIAL)
        self._recorder(rec.INITIAL,
                       "Nonlinear Conditional EIM - Reparametrization", config)
        self._recorder.initialize_module(rec.MODEL, self.c.train_epochs)
        self._recorder.initialize_module(rec.WEIGHTS_UPDATE,
                                         self.c.train_epochs)
        self._recorder.initialize_module(rec.COMPONENT_UPDATE,
                                         self.c.train_epochs,
                                         self.c.num_components)
        self._recorder.initialize_module(rec.DRE, self.c.train_epochs)

    def train(self):
        for i in range(self.c.train_epochs):
            self._recorder(rec.TRAIN_ITER, i)
            self.train_iter(i)

    # separate function so that individual iterations can be run from cluster workers
    def train_iter(self, i):

        dre_steps, loss, acc = self._dre.train(self._model,
                                               self.c.dre_batch_size,
                                               self.c.dre_num_iters)
        self._recorder(rec.DRE, self._dre, self._model, i, dre_steps)

        if self._model.num_components > 1:
            w_res = self.update_gating()
            self._recorder(rec.WEIGHTS_UPDATE, w_res)

        c_res = self.update_components()
        self._recorder(rec.COMPONENT_UPDATE, c_res)

        self._recorder(rec.MODEL, self._model, i)

    """component update"""

    def update_components(self):
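        # Responsibilities p(component | context) of the current gating network,
        # normalized over the training set, act as importance weights in the
        # per-component update below.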
        importance_weights = self._model.gating_distribution.probabilities(
            self._train_contexts)
        importance_weights = importance_weights / tf.reduce_sum(
            importance_weights, axis=0, keepdims=True)

        old_means, old_chol_covars = self._model.get_component_parameters(
            self._train_contexts)

        rhs = tf.eye(tf.shape(old_means)[-1],
                     batch_shape=tf.shape(old_chol_covars)[:-2])
        stab_fact = 1e-20
        old_chol_inv = tf.linalg.triangular_solve(
            old_chol_covars + stab_fact * rhs, rhs)

        for i in range(self.c.components_num_epochs):
            self._components_train_step(importance_weights, old_means,
                                        old_chol_inv)

        res_list = []
        for i, c in enumerate(self._model.components):
            expected_entropy = np.sum(importance_weights[:, i] *
                                      c.entropies(self._train_contexts))
            kls = c.kls_other_chol_inv(self._train_contexts, old_means[:, i],
                                       old_chol_inv[:, i])
            expected_kl = np.sum(importance_weights[:, i] * kls)
            res_list.append((self.c.components_num_epochs, expected_kl,
                             expected_entropy, ""))
        return res_list

    @tf.function
    def _components_train_step(self, importance_weights, old_means,
                               old_chol_precisions):
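        # One pass of reparametrized gradient updates per component: samples are
        # drawn through the component (so gradients flow through the sampling),
        # scored by the density ratio estimator, and regularized by the KL to the
        # old component parameters.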
        for i in range(self._model.num_components):
            dt = (self._train_contexts, importance_weights[:, i], old_means,
                  old_chol_precisions)
            data = tf.data.Dataset.from_tensor_slices(dt)
            data = data.shuffle(self._train_contexts.shape[0]).batch(
                self.c.components_batch_size)

            for context_batch, iw_batch, old_means_batch, old_chol_precisions_batch in data:
                iw_batch = iw_batch / tf.reduce_sum(iw_batch)
                with tf.GradientTape() as tape:
                    samples = self._model.components[i].sample(context_batch)
                    losses = -tf.squeeze(
                        self._dre(tf.concat([context_batch, samples],
                                            axis=-1)))
                    kls = self._model.components[i].kls_other_chol_inv(
                        context_batch, old_means_batch[:, i],
                        old_chol_precisions_batch[:, i])
                    loss = tf.reduce_mean(iw_batch * (losses + kls))
                gradients = tape.gradient(
                    loss, self._model.components[i].trainable_variables)
                self._c_opts[i].apply_gradients(
                    zip(gradients,
                        self._model.components[i].trainable_variables))

    """gating update"""

    def update_gating(self):
        old_probs = self._model.gating_distribution.probabilities(
            self._train_contexts)
        for i in range(self.c.gating_num_epochs):
            self._gating_train_step(old_probs)

        expected_entropy = self._model.gating_distribution.expected_entropy(
            self._train_contexts)
        expected_kl = self._model.gating_distribution.expected_kl(
            self._train_contexts, old_probs)

        return self.c.gating_num_epochs, expected_kl, expected_entropy, ""

    @tf.function
    def _gating_train_step(self, old_probs):
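        # Score one sample per component and context with the density ratio
        # estimator; the gating loss is the probability-weighted expected score
        # plus the expected KL to the old gating distribution.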
        losses = []
        for i in range(self.c.num_components):
            samples = self._model.components[i].sample(self._train_contexts)
            losses.append(-self._dre(
                tf.concat([self._train_contexts, samples], axis=-1)))

        losses = tf.concat(losses, axis=1)
        data = tf.data.Dataset.from_tensor_slices(
            (self._train_contexts, losses, old_probs))
        data = data.shuffle(self._train_contexts.shape[0]).batch(
            self.c.gating_batch_size)
        for context_batch, losses_batch, old_probs_batch in data:
            with tf.GradientTape() as tape:
                probabilities = self._model.gating_distribution.probabilities(
                    context_batch)
                kl = self._model.gating_distribution.expected_kl(
                    context_batch, old_probs_batch)
                loss = tf.reduce_sum(
                    tf.reduce_mean(probabilities * losses_batch, 0)) + kl
            gradients = tape.gradient(
                loss, self._model.gating_distribution.trainable_variables)
            self._g_opt.apply_gradients(
                zip(gradients,
                    self._model.gating_distribution.trainable_variables))

    @property
    def model(self):
        return self._model
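
A hedged end-to-end sketch for this class, assuming the example's imports (numpy as np) and illustrative data shapes; the recorder construction is project-specific and left as a placeholder:

config = ConditionalMixtureEIM.get_default_config()
config.num_components = 3  # presumably modifiable between finalize_adding() and finalize_modifying()

contexts = np.random.randn(1000, 2).astype(np.float32)  # (N, context_dim)
targets = np.random.randn(1000, 4).astype(np.float32)   # (N, sample_dim)
recorder = ...  # a Recorder instance; construction depends on the surrounding project

eim = ConditionalMixtureEIM(config, (contexts, targets), recorder)
eim.train()
trained_model = eim.model  # the underlying GaussianEMM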
Example 4
    def __init__(self,
                 config,
                 train_samples,
                 recorder,
                 val_samples=None,
                 seed=0,
                 add_feat_fn=None):

        self.c = config
        self.c.finalize_modifying()

        # build model
        w, m, c = model_init.gmm_init(self.c.initialization,
                                      np.array(train_samples,
                                               dtype=np.float32),
                                      self.c.num_components,
                                      seed=0)
        self._gmm_learner = GMMLearner(dim=m.shape[-1],
                                       num_components=self.c.num_components,
                                       surrogate_reg_fact=1e-10,
                                       seed=seed,
                                       eta_offset=1.0,
                                       omega_offset=0.0,
                                       constrain_entropy=False)
        self._gmm_learner.initialize_model(w.astype(np.float64),
                                           m.astype(np.float64),
                                           c.astype(np.float64))

        # build density ratio estimator
        dre_params = {
            NetworkKeys.NUM_UNITS: self.c.dre_hidden_layers,
            NetworkKeys.ACTIVATION: k.activations.relu,
            NetworkKeys.DROP_PROB: self.c.dre_drop_prob,
            NetworkKeys.L2_REG_FACT: self.c.dre_reg_loss_fact
        }
        if add_feat_fn is not None:
            self._dre = AddFeatDensityRatioEstimator(
                target_train_samples=train_samples,
                hidden_params=dre_params,
                early_stopping=self.c.dre_early_stopping,
                target_val_samples=val_samples,
                additional_feature_fn=add_feat_fn)
            name = "Marginal EIM - Additional Features"
        else:
            self._dre = DensityRatioEstimator(
                target_train_samples=train_samples,
                hidden_params=dre_params,
                early_stopping=self.c.dre_early_stopping,
                target_val_samples=val_samples)
            name = "Marginal EIM"

        # build recording
        self._recorder = recorder
        self._recorder.initialize_module(rec.INITIAL)
        self._recorder(rec.INITIAL, name, config)
        self._recorder.initialize_module(rec.MODEL, self.c.train_epochs)
        self._recorder.initialize_module(rec.WEIGHTS_UPDATE,
                                         self.c.train_epochs)
        self._recorder.initialize_module(rec.COMPONENT_UPDATE,
                                         self.c.train_epochs,
                                         self.c.num_components)
        self._recorder.initialize_module(rec.DRE, self.c.train_epochs)
Example 5
class MarginalMixtureEIM:
    @staticmethod
    def get_default_config():
        c = ConfigDict(
            num_components=1,
            samples_per_component=500,
            train_epochs=1000,
            initialization="random",
            # Component Updates
            component_kl_bound=0.01,
            # Mixture Updates
            weight_kl_bound=0.01,
            # Density Ratio Estimation
            dre_reg_loss_fact=0.0,  # scaling factor for the L2 regularization of the density ratio estimator
            dre_early_stopping=True,  # use early stopping for density ratio estimator training
            dre_drop_prob=0.0,  # dropout probability for the density ratio estimator (0.0 disables dropout)
            dre_num_iters=1000,  # number of density ratio estimator steps per iteration (max number if early stopping is used)
            dre_batch_size=1000,  # batch size for density ratio estimator training
            dre_hidden_layers=[30, 30]  # width of the density ratio estimator's hidden layers
        )
        c.finalize_adding()
        return c

    def __init__(self,
                 config,
                 train_samples,
                 recorder,
                 val_samples=None,
                 seed=0,
                 add_feat_fn=None):

        self.c = config
        self.c.finalize_modifying()

        # build model
        w, m, c = model_init.gmm_init(self.c.initialization,
                                      np.array(train_samples,
                                               dtype=np.float32),
                                      self.c.num_components,
                                      seed=0)
        self._gmm_learner = GMMLearner(dim=m.shape[-1],
                                       num_components=self.c.num_components,
                                       surrogate_reg_fact=1e-10,
                                       seed=seed,
                                       eta_offset=1.0,
                                       omega_offset=0.0,
                                       constrain_entropy=False)
        self._gmm_learner.initialize_model(w.astype(np.float64),
                                           m.astype(np.float64),
                                           c.astype(np.float64))

        # build density ratio estimator
        dre_params = {
            NetworkKeys.NUM_UNITS: self.c.dre_hidden_layers,
            NetworkKeys.ACTIVATION: k.activations.relu,
            NetworkKeys.DROP_PROB: self.c.dre_drop_prob,
            NetworkKeys.L2_REG_FACT: self.c.dre_reg_loss_fact
        }
        if add_feat_fn is not None:
            self._dre = AddFeatDensityRatioEstimator(
                target_train_samples=train_samples,
                hidden_params=dre_params,
                early_stopping=self.c.dre_early_stopping,
                target_val_samples=val_samples,
                additional_feature_fn=add_feat_fn)
            name = "Marginal EIM - Additional Features"
        else:
            self._dre = DensityRatioEstimator(
                target_train_samples=train_samples,
                hidden_params=dre_params,
                early_stopping=self.c.dre_early_stopping,
                target_val_samples=val_samples)
            name = "Marginal EIM"

        # build recording
        self._recorder = recorder
        self._recorder.initialize_module(rec.INITIAL)
        self._recorder(rec.INITIAL, name, config)
        self._recorder.initialize_module(rec.MODEL, self.c.train_epochs)
        self._recorder.initialize_module(rec.WEIGHTS_UPDATE,
                                         self.c.train_epochs)
        self._recorder.initialize_module(rec.COMPONENT_UPDATE,
                                         self.c.train_epochs,
                                         self.c.num_components)
        self._recorder.initialize_module(rec.DRE, self.c.train_epochs)

    @property
    def model(self):
        return self._gmm_learner.model

    def train(self):
        for i in range(self.c.train_epochs):
            self._recorder(rec.TRAIN_ITER, i)

            # update density ratio estimator
            dre_steps, loss, acc = self._dre.train(self.model,
                                                   self.c.dre_batch_size,
                                                   self.c.dre_num_iters)
            self._recorder(rec.DRE, self._dre, self.model, i, dre_steps)

            # M-Step weights
            if self.model.num_components > 1:
                w_res = self.update_weight()
                self._recorder(rec.WEIGHTS_UPDATE, w_res)

            # M-Step components
            c_res = self.update_components()
            self._recorder(rec.COMPONENT_UPDATE, c_res)

            self._recorder(rec.MODEL, self.model, i)

    def _get_reward(self, samples_as_list):
        rewards = []
        for i, samples in enumerate(samples_as_list):
            rewards.append(self._dre(samples))
        return rewards

    def _get_samples(self):
        samples = []
        for c in self.model.components:
            samples.append(c.sample(self.c.samples_per_component))
        return samples

    def update_components(self):
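        # Draw fresh samples from each component, score them with the density
        # ratio estimator, and delegate the KL-bounded update (component_kl_bound)
        # to the GMM learner.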
        fake_samples = self._get_samples()
        rewards = self._get_reward(fake_samples)

        fake_samples = np.ascontiguousarray(
            [np.array(s, dtype=np.float64) for s in fake_samples])
        rewards = np.ascontiguousarray(
            [np.array(r, dtype=np.float64) for r in rewards])

        res_list = self._gmm_learner.update_components(
            fake_samples, rewards, self.c.component_kl_bound)

        return [(1, res[0], res[1], res[-1]) for res in res_list]

    def update_weight(self):
        fake_samples = self._get_samples()
        rewards = np.mean(self._get_reward(fake_samples), 1)
        rewards = np.ascontiguousarray(
            np.concatenate(rewards, -1).astype(np.float64))
        res = self._gmm_learner.update_weights(rewards, self.c.weight_kl_bound)
        return 1, res[0], res[1], res[-1]
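
The two KL bounds are the main tuning knobs of this variant; a hedged configuration sketch with purely illustrative values:

config = MarginalMixtureEIM.get_default_config()
config.component_kl_bound = 0.05     # trust region for the per-component updates
config.weight_kl_bound = 0.005       # trust region for the weight update
config.samples_per_component = 1000  # samples drawn per component for each update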
Example 6
class MarginalMixtureEIM:

    @staticmethod
    def get_default_config():
        c = ConfigDict(
            num_components=1,
            samples_per_component=500,
            train_epochs=1000,
            initialization="random",
            # Component Updates
            component_kl_bound=0.01,
            # Mixture Updates
            weight_kl_bound=0.01,
            # Density Ratio Estimation
            dre_reg_loss_fact=0.0,  # scaling factor for the L2 regularization of the density ratio estimator
            dre_early_stopping=True,  # use early stopping for density ratio estimator training
            dre_drop_prob=0.0,  # dropout probability for the density ratio estimator (0.0 disables dropout)
            dre_num_iters=1000,  # number of density ratio estimator steps per iteration (max number if early stopping is used)
            dre_batch_size=1000,  # batch size for density ratio estimator training
            dre_hidden_layers=[30, 30]  # width of the density ratio estimator's hidden layers
        )
        c.finalize_adding()
        return c

    def __init__(self, config, train_samples, recorder, val_samples=None, seed=0, add_feat_fn=None):

        self.c = config
        self.c.finalize_modifying()

        # build model
        w, m, c = model_init.gmm_init(self.c.initialization, np.array(train_samples, dtype=np.float32),
                                      self.c.num_components, seed=0)
        self._model = GMM(w, m, c)

        self._components_learners = []
        for i in range(self._model.num_components):
            self._components_learners.append(MoreGaussian(m.shape[-1], 1.0, 0.0, False))
        self._weight_learner = RepsCategorical(1.0, 0.0, False)

        # build density ratio estimator
        dre_params = {NetworkKeys.NUM_UNITS: self.c.dre_hidden_layers,
                      NetworkKeys.ACTIVATION: k.activations.relu,
                      NetworkKeys.DROP_PROB: self.c.dre_drop_prob,
                      NetworkKeys.L2_REG_FACT: self.c.dre_reg_loss_fact}
        if add_feat_fn is not None:
            self._dre = AddFeatDensityRatioEstimator(target_train_samples=train_samples, hidden_params=dre_params,
                                                     early_stopping=self.c.dre_early_stopping,
                                                     target_val_samples=val_samples, additional_feature_fn=add_feat_fn)
            name = "Marginal EIM - Additional Features"
        else:
            self._dre = DensityRatioEstimator(target_train_samples=train_samples, hidden_params=dre_params,
                                              early_stopping=self.c.dre_early_stopping, target_val_samples=val_samples)
            name = "Marginal EIM"

        # build recording
        self._recorder = recorder
        self._recorder.initialize_module(rec.INITIAL)
        self._recorder(rec.INITIAL, name, config)
        self._recorder.initialize_module(rec.MODEL, self.c.train_epochs)
        self._recorder.initialize_module(rec.WEIGHTS_UPDATE, self.c.train_epochs)
        self._recorder.initialize_module(rec.COMPONENT_UPDATE, self.c.train_epochs, self.c.num_components)
        self._recorder.initialize_module(rec.DRE, self.c.train_epochs)

    @property
    def model(self):
        return self._model

    def train(self):
        for i in range(self.c.train_epochs):
            self._recorder(rec.TRAIN_ITER, i)

            # update density ratio estimator
            dre_steps, loss, acc = self._dre.train(self.model, self.c.dre_batch_size, self.c.dre_num_iters)
            self._recorder(rec.DRE, self._dre, self.model, i, dre_steps)

            # M-Step weights
            if self.model.num_components > 1:
                w_res = self.update_weight()
                self._recorder(rec.WEIGHTS_UPDATE, w_res)

            # M-Step components
            c_res = self.update_components()
            self._recorder(rec.COMPONENT_UPDATE, c_res)

            self._recorder(rec.MODEL, self.model, i)

    def _get_reward(self, samples_as_list):
        rewards = []
        for i, samples in enumerate(samples_as_list):
            rewards.append(self._dre(samples))
        return rewards

    def _get_samples(self):
        samples = []
        for c in self.model.components:
            samples.append(c.sample(self.c.samples_per_component))
        return samples

    def update_components(self):
        samples = self._get_samples()
        rewards = self._get_reward(samples)

        res_list = []

        for i in range(self._model.num_components):
            component = self._model.components[i]
            learner = self._components_learners[i]

            old_dist = Gaussian(component.mean, component.covar)

            surrogate = QuadFunc(1e-12, normalize=True, unnormalize_output=False)
            surrogate.fit(samples[i], rewards[i], None, old_dist.mean, old_dist.chol_covar)

            # This is a numerical detail we did not use in the original paper: we do not undo the output
            # normalization of the regression. This yields the same solution, but the optimal Lagrangian
            # multipliers of the MORE dual are scaled, so the offset needs to be adapted accordingly. This
            # makes optimizing the dual much more stable and insensitive to initialization.
            learner.eta_offset = 1.0 / surrogate.o_std

            new_mean, new_covar = learner.more_step(self.c.component_kl_bound, -1, component, surrogate)
            if learner.success:
                component.update_parameters(new_mean, new_covar)
                res_list.append((1, component.kl(old_dist), component.entropy(), " "))
            else:
                res_list.append((1, 0.0, old_dist.entropy(), "update of component {:d} failed".format(i)))

        return res_list

    def update_weight(self):
        fake_samples = self._get_samples()
        rewards = np.mean(self._get_reward(fake_samples), 1)
        rewards = np.ascontiguousarray(np.concatenate(rewards, -1).astype(np.float64))

        old_dist = Categorical(self._model.weight_distribution.probabilities)

        # the entropy bound of -1 is a dummy value, as the entropy is not constrained
        new_probabilities = self._weight_learner.reps_step(self.c.weight_kl_bound, -1, old_dist, rewards)
        if self._weight_learner.success:
            self._model.weight_distribution.probabilities = new_probabilities

        kl = self._model.weight_distribution.kl(old_dist)
        entropy = self._model.weight_distribution.entropy()
        return 1, kl, entropy, " "
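
A hedged sketch of an additional-feature function for the add_feat_fn path above, reusing train_samples and recorder from the sketches under Example 2; the exact signature expected by AddFeatDensityRatioEstimator is an assumption inferred from the additional_feature_fn keyword:

def add_quadratic_features(samples):
    # hypothetical: supply element-wise squares as extra density ratio estimator features
    return samples ** 2

eim = MarginalMixtureEIM(MarginalMixtureEIM.get_default_config(), train_samples,
                         recorder, add_feat_fn=add_quadratic_features)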