Example #1
    def __init__(self,
                 state_shape,
                 action_dim,
                 max_action,
                 units=[256, 256],
                 hidden_activation="relu",
                 fix_std=False,
                 const_std=0.1,
                 state_independent_std=False,
                 name='GaussianPolicy'):
        super(GaussianActor, self).__init__(name=name)
        self.dist = DiagonalGaussian(dim=action_dim)
        self._fix_std = fix_std
        self._const_std = const_std
        self._max_action = max_action
        self._state_independent_std = state_independent_std

        self.l1 = Dense(units[0], name="L1", activation=hidden_activation)
        self.l2 = Dense(units[1], name="L2", activation=hidden_activation)
        self.out_mean = Dense(action_dim, name="L_mean")

        if not self._fix_std:
            # check whether the std is state-independent
            if self._state_independent_std:
                self.out_log_std = tf.Variable(
                    initial_value=0.5*np.ones(action_dim, dtype=np.float32),
                    dtype=tf.float32, name="logstd"
                )
            else:
                self.out_log_std = Dense(action_dim, name="L_sigma")

        self(tf.constant(np.zeros(shape=(1,)+state_shape, dtype=np.float32)))
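All of these examples lean on a project-specific DiagonalGaussian helper. As a rough reference, here is a minimal NumPy sketch of the diagonal-Gaussian quantities such a helper typically wraps (the function names are illustrative, not the actual API of any of these projects):

import numpy as np

def diag_gaussian_log_likelihood(x, mean, log_std):
    # log N(x; mean, diag(exp(log_std))^2), summed over the action dimensions
    zs = (x - mean) / np.exp(log_std)
    return (-np.sum(log_std, axis=-1)
            - 0.5 * np.sum(zs ** 2, axis=-1)
            - 0.5 * x.shape[-1] * np.log(2.0 * np.pi))

def diag_gaussian_kl(mean_old, log_std_old, mean_new, log_std_new):
    # KL(old || new) between two diagonal Gaussians, summed over dimensions
    var_old, var_new = np.exp(2 * log_std_old), np.exp(2 * log_std_new)
    return np.sum(log_std_new - log_std_old
                  + (var_old + (mean_old - mean_new) ** 2) / (2 * var_new)
                  - 0.5, axis=-1)

def diag_gaussian_entropy(log_std):
    # differential entropy, summed over dimensions
    return np.sum(log_std + 0.5 * np.log(2.0 * np.pi * np.e), axis=-1)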
Example #2
 def __init__(self, observation_space, action_space, task_q, result_q):
     multiprocessing.Process.__init__(self)
     self.task_q = task_q
     self.result_q = result_q
     self.observation_space = observation_space
     self.action_space = action_space
     self.args = pms
     self.baseline = Baseline()
     self.distribution = DiagonalGaussian(pms.action_shape)
    def __init__(self, thread_id):
        print "create worker %d" % (thread_id)
        self.thread_id = thread_id
        self.env = env = Environment(gym.make(pms.environment_name))
        # print("Observation Space", env.observation_space)
        # print("Action Space", env.action_space)
        # print("Action area, high:%f, low%f" % (env.action_space.high, env.action_space.low))
        self.end_count = 0
        self.paths = []
        self.train = True
        self.baseline = Baseline()
        self.storage = Storage(self, self.env, self.baseline)
        self.distribution = DiagonalGaussian(pms.action_shape)

        self.session = self.master.session
        self.init_network()
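The Storage/Baseline machinery used by these agents reduces to discounted returns and, in the later examples, GAE advantages. A minimal NumPy sketch of those two computations under the same conventions (discount and gae_advantages are illustrative names, not the project's own helpers):

import numpy as np

def discount(rewards, gamma):
    # discounted cumulative sums: out[t] = sum_k gamma**k * rewards[t + k]
    rewards = np.asarray(rewards, dtype=np.float64)
    out = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out

def gae_advantages(rewards, baselines, gamma, lam):
    # generalized advantage estimation: discount the TD residuals by gamma * lam,
    # bootstrapping with a terminal value of 0 as in the process_paths methods below
    baselines = np.append(np.asarray(baselines, dtype=np.float64), 0.0)
    deltas = np.asarray(rewards, dtype=np.float64) + gamma * baselines[1:] - baselines[:-1]
    return discount(deltas, gamma * lam)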
Example #4
 def __init__(self):
     self.env = env = Environment(gym.make(pms.environment_name))
     # if not isinstance(env.observation_space, Box) or \
     #    not isinstance(env.action_space, Discrete):
     #     print("Incompatible spaces.")
     #     exit(-1)
     print("Observation Space", env.observation_space)
     print("Action Space", env.action_space)
     print("Action area, high:%f, low%f" %
           (env.action_space.high, env.action_space.low))
     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1 / 3.0)
     self.session = tf.Session(config=tf.ConfigProto(
         gpu_options=gpu_options))
     self.end_count = 0
     self.paths = []
     self.train = True
     self.baseline = Baseline(self.session)
     self.storage = Storage(self, self.env, self.baseline)
     self.distribution = DiagonalGaussian(pms.action_shape)
     self.init_network()
class TRPOAgentContinousSingleProcess(object):
    def __init__(self, thread_id):
        print "create worker %d" % (thread_id)
        self.thread_id = thread_id
        self.env = env = Environment(gym.make(pms.environment_name))
        # print("Observation Space", env.observation_space)
        # print("Action Space", env.action_space)
        # print("Action area, high:%f, low%f" % (env.action_space.high, env.action_space.low))
        self.end_count = 0
        self.paths = []
        self.train = True
        self.baseline = Baseline()
        self.storage = Storage(self, self.env, self.baseline)
        self.distribution = DiagonalGaussian(pms.action_shape)

        self.session = self.master.session
        self.init_network()

    def init_network(self):
        self.network = NetworkContinous(str(self.thread_id))
        log_std_var = self.network.action_dist_logstds_n
        if pms.min_std is not None:
            # clip the log-std from below so the policy std never collapses
            log_std_var = tf.maximum(log_std_var, np.log(pms.min_std))
        self.action_dist_stds_n = tf.exp(log_std_var)

        self.old_dist_info_vars = dict(mean=self.network.old_dist_means_n,
                                       log_std=self.network.old_dist_logstds_n)
        self.new_dist_info_vars = dict(
            mean=self.network.action_dist_means_n,
            log_std=self.network.action_dist_logstds_n)
        self.likehood_action_dist = self.distribution.log_likelihood_sym(
            self.network.action_n, self.new_dist_info_vars)
        self.ratio_n = self.distribution.likelihood_ratio_sym(
            self.network.action_n, self.new_dist_info_vars,
            self.old_dist_info_vars)

        surr = -tf.reduce_mean(
            self.ratio_n * self.network.advant)  # Surrogate loss
        kl = tf.reduce_mean(
            self.distribution.kl_sym(self.old_dist_info_vars,
                                     self.new_dist_info_vars))
        ent = self.distribution.entropy(self.old_dist_info_vars)
        # ent = tf.reduce_sum(-p_n * tf.log(p_n + eps)) / Nf
        self.losses = [surr, kl, ent]
        var_list = self.network.var_list
        self.gf = GetFlat(self.session, var_list)  # get theta from var_list
        self.sff = SetFromFlat(self.session,
                               var_list)  # set theta from var_List
        # get g
        self.pg = flatgrad(surr, var_list)
        # get A

        # KL divergence where first arg is fixed
        # replace old->tf.stop_gradient from previous kl
        kl_firstfixed = kl_sym_gradient(self.network.old_dist_means_n,
                                        self.network.old_dist_logstds_n,
                                        self.network.action_dist_means_n,
                                        self.network.action_dist_logstds_n)

        grads = tf.gradients(kl, var_list)
        self.flat_tangent = tf.placeholder(dtype, shape=[None])
        shapes = map(var_shape, var_list)
        start = 0
        tangents = []
        for shape in shapes:
            size = np.prod(shape)
            param = tf.reshape(self.flat_tangent[start:(start + size)], shape)
            tangents.append(param)
            start += size
        self.gvp = [tf.reduce_sum(g * t) for (g, t) in zip(grads, tangents)]
        self.fvp = flatgrad(tf.reduce_sum(self.gvp), var_list)  # get kl''*p
        # self.load_model()

    def get_samples(self, path_number):
        for i in range(pms.paths_number):
            self.storage.get_single_path()

    def get_action(self, obs, *args):
        obs = np.expand_dims(obs, 0)
        # action_dist_logstd = np.expand_dims([np.log(pms.std)], 0)
        if pms.use_std_network:
            action_dist_means_n, action_dist_logstds_n = self.session.run(
                [self.action_dist_means_n, self.action_dist_logstds_n],
                {self.obs: obs})
            if pms.train_flag:
                rnd = np.random.normal(size=action_dist_means_n[0].shape)
                action = rnd * np.exp(
                    action_dist_logstds_n[0]) + action_dist_means_n[0]
            else:
                action = action_dist_means_n[0]
            # action = np.clip(action, pms.min_a, pms.max_a)
            return action, dict(mean=action_dist_means_n[0],
                                log_std=action_dist_logstds_n[0])
        else:
            action_dist_logstd = np.expand_dims([np.log(pms.std)], 0)
            action_dist_means_n = self.network.get_action_dist_means_n(
                self.session, obs)
            if pms.train_flag:
                rnd = np.random.normal(size=action_dist_means_n[0].shape)
                action = rnd * np.exp(
                    action_dist_logstd[0]) + action_dist_means_n[0]
            else:
                action = action_dist_means_n[0]
            # action = np.clip(action, pms.min_a, pms.max_a)
            return action, dict(mean=action_dist_means_n[0],
                                log_std=action_dist_logstd[0])

    def run(self):
        self.learn()

    def learn(self):
        start_time = time.time()

        numeptotal = 0
        i = 0
        while True:
            # Generating paths.
            # print("Rollout")
            self.get_samples(pms.paths_number)
            paths = self.storage.get_paths()  # get_paths
            # Computing returns and estimating advantage function.
            sample_data = self.storage.process_paths(paths)

            agent_infos = sample_data["agent_infos"]
            obs_n = sample_data["observations"]
            action_n = sample_data["actions"]
            advant_n = sample_data["advantages"]
            n_samples = len(obs_n)
            inds = np.random.choice(n_samples,
                                    int(math.floor(n_samples *
                                                   pms.subsample_factor)),
                                    replace=False)
            obs_n = obs_n[inds]
            action_n = action_n[inds]
            advant_n = advant_n[inds]
            action_dist_means_n = np.array(
                [agent_info["mean"] for agent_info in agent_infos[inds]])
            action_dist_logstds_n = np.array(
                [agent_info["log_std"] for agent_info in agent_infos[inds]])
            feed = {
                self.network.obs: obs_n,
                self.network.advant: advant_n,
                self.network.old_dist_means_n: action_dist_means_n,
                self.network.old_dist_logstds_n: action_dist_logstds_n,
                self.network.action_dist_logstds_n: action_dist_logstds_n,
                self.network.action_n: action_n
            }

            episoderewards = np.array(
                [path["rewards"].sum() for path in paths])
            average_episode_std = np.mean(np.exp(action_dist_logstds_n))

            # print "\n********** Iteration %i ************" % i
            for iter_num_per_train in range(pms.iter_num_per_train):
                # if not self.train:
                #     print("Episode mean: %f" % episoderewards.mean())
                #     self.end_count += 1
                #     if self.end_count > 100:
                #         break
                if self.train:
                    thprev = self.gf()  # get theta_old

                    def fisher_vector_product(p):
                        feed[self.flat_tangent] = p
                        return self.session.run(self.fvp,
                                                feed) + pms.cg_damping * p

                    g = self.session.run(self.pg, feed_dict=feed)
                    stepdir = krylov.cg(fisher_vector_product,
                                        g,
                                        cg_iters=pms.cg_iters)
                    shs = 0.5 * stepdir.dot(
                        fisher_vector_product(stepdir))  # theta
                    fullstep = stepdir * np.sqrt(2.0 * pms.max_kl / shs)
                    neggdotstepdir = -g.dot(stepdir)

                    def loss(th):
                        self.sff(th)
                        return self.session.run(self.losses, feed_dict=feed)

                    surr_prev, kl_prev, ent_prev = loss(thprev)
                    mean_advant = np.mean(advant_n)
                    theta = linesearch(loss, thprev, fullstep, neggdotstepdir)
                    self.sff(theta)
                    surrafter, kloldnew, entnew = self.session.run(
                        self.losses, feed_dict=feed)
                    stats = {}
                    numeptotal += len(episoderewards)
                    stats["average_episode_std"] = average_episode_std
                    stats["sum steps of episodes"] = sample_data[
                        "sum_episode_steps"]
                    stats["Total number of episodes"] = numeptotal
                    stats[
                        "Average sum of rewards per episode"] = episoderewards.mean(
                        )
                    # stats["Entropy"] = entropy
                    # exp = explained_variance(np.array(baseline_n), np.array(returns_n))
                    # stats["Baseline explained"] = exp
                    stats["Time elapsed"] = "%.2f mins" % (
                        (time.time() - start_time) / 60.0)
                    stats["KL between old and new distribution"] = kloldnew
                    stats["Surrogate loss"] = surrafter
                    stats["Surrogate loss prev"] = surr_prev
                    stats["entropy"] = ent_prev
                    stats["mean_advant"] = mean_advant
                    log_data = [
                        average_episode_std,
                        len(episoderewards), numeptotal,
                        episoderewards.mean(), kloldnew, surrafter, surr_prev,
                        surrafter - surr_prev, ent_prev, mean_advant
                    ]
                    self.master.logger.log_row(log_data)
                    # for k, v in stats.iteritems():
                    #     print(k + ": " + " " * (40 - len(k)) + str(v))
                    #     # if entropy != entropy:
                    #     #     exit(-1)
                    #     # if exp > 0.95:
                    #     #     self.train = False
            if self.thread_id == 1:
                self.master.save_model("iter" + str(i))
                print(episoderewards.mean())
            i += 1

    def test(self, model_name):
        self.load_model(model_name)
        for i in range(50):
            self.storage.get_single_path()

    def save_model(self, model_name):
        self.saver.save(self.session, "checkpoint/" + model_name + ".ckpt")

    def load_model(self, model_name):
        try:
            self.saver.restore(self.session, model_name)
        except Exception:
            print("load model %s fail" % model_name)
Example #6
class TRPOAgentParallel(multiprocessing.Process):
    def __init__(self, observation_space, action_space, task_q, result_q):
        multiprocessing.Process.__init__(self)
        self.task_q = task_q
        self.result_q = result_q
        self.observation_space = observation_space
        self.action_space = action_space
        self.args = pms
        self.baseline = Baseline()
        self.distribution = DiagonalGaussian(pms.action_shape)
        self.init_logger()

    def init_network(self):
        """
        [input]
        self.obs
        self.action_n
        self.advant
        self.old_dist_means_n
        self.old_dist_logstds_n
        [output]
        self.action_dist_means_n
        self.action_dist_logstds_n
        var_list
        """
        config = tf.ConfigProto(device_count={'GPU': 0})
        self.session = tf.Session(config=config)
        self.net = NetworkContinous("network_continous")
        log_std_var = self.net.action_dist_logstds_n
        if pms.min_std is not None:
            # clip the log-std from below so the policy std never collapses
            log_std_var = tf.maximum(log_std_var, np.log(pms.min_std))
        self.action_dist_stds_n = tf.exp(log_std_var)
        self.old_dist_info_vars = dict(mean=self.net.old_dist_means_n,
                                       log_std=self.net.old_dist_logstds_n)
        self.new_dist_info_vars = dict(mean=self.net.action_dist_means_n,
                                       log_std=self.net.action_dist_logstds_n)
        self.likehood_action_dist = self.distribution.log_likelihood_sym(
            self.net.action_n, self.new_dist_info_vars)
        self.ratio_n = self.distribution.likelihood_ratio_sym(
            self.net.action_n, self.new_dist_info_vars,
            self.old_dist_info_vars)
        surr = -tf.reduce_mean(
            self.ratio_n * self.net.advant)  # Surrogate loss
        batch_size = tf.shape(self.net.obs)[0]
        batch_size_float = tf.cast(batch_size, tf.float32)
        kl = tf.reduce_mean(
            self.distribution.kl_sym(self.old_dist_info_vars,
                                     self.new_dist_info_vars))
        ent = self.distribution.entropy(self.old_dist_info_vars)
        # ent = tf.reduce_sum(-p_n * tf.log(p_n + eps)) / Nf
        self.losses = [surr, kl, ent]
        var_list = self.net.var_list

        self.gf = GetFlat(var_list)  # get theta from var_list
        self.gf.session = self.session
        self.sff = SetFromFlat(var_list)  # set theta from var_List
        self.sff.session = self.session
        # get g
        self.pg = flatgrad(surr, var_list)
        # get A
        # KL divergence where first arg is fixed
        # replace old->tf.stop_gradient from previous kl
        kl_firstfixed = self.distribution.kl_sym_firstfixed(
            self.new_dist_info_vars) / batch_size_float
        grads = tf.gradients(kl_firstfixed, var_list)
        self.flat_tangent = tf.placeholder(dtype, shape=[None])
        shapes = map(var_shape, var_list)
        start = 0
        tangents = []
        for shape in shapes:
            size = np.prod(shape)
            param = tf.reshape(self.flat_tangent[start:(start + size)], shape)
            tangents.append(param)
            start += size
        self.gvp = [tf.reduce_sum(g * t) for (g, t) in zip(grads, tangents)]
        self.fvp = flatgrad(tf.reduce_sum(self.gvp), var_list)  # get kl''*p
        self.session.run(tf.initialize_all_variables())
        self.saver = tf.train.Saver(max_to_keep=5)

    def init_logger(self):
        head = ["factor", "rewards", "std"]
        self.logger = Logger(head)

    def run(self):
        self.init_network()
        while True:
            paths = self.task_q.get()
            if paths is None:
                # kill the learner
                self.task_q.task_done()
                break
            elif paths == 1:
                # just get params, no learn
                self.task_q.task_done()
                self.result_q.put(self.gf())
            elif paths[0] == 2:
                # adjusting the max KL.
                self.args.max_kl = paths[1]
                if paths[2] == 1:
                    print "saving checkpoint..."
                    self.save_model(pms.environment_name + "-" + str(paths[3]))
                self.task_q.task_done()
            else:
                stats, theta, thprev = self.learn(paths, linear_search=False)
                self.sff(theta)
                self.task_q.task_done()
                self.result_q.put((stats, theta, thprev))
        return

    def learn(self, paths, parallel=False, linear_search=False):
        start_time = time.time()
        sample_data = self.process_paths(paths)
        agent_infos = sample_data["agent_infos"]
        obs_all = sample_data["observations"]
        action_all = sample_data["actions"]
        advant_all = sample_data["advantages"]
        n_samples = len(obs_all)
        batch = int(1 / pms.subsample_factor)
        batch_size = int(math.floor(n_samples * pms.subsample_factor))
        accum_fullstep = 0.0
        for iteration in range(batch):
            print "batch: %d, batch_size: %d" % (iteration + 1, batch_size)
            inds = np.random.choice(n_samples, batch_size, replace=False)
            obs_n = obs_all[inds]
            action_n = action_all[inds]
            advant_n = advant_all[inds]
            action_dist_means_n = np.array(
                [agent_info["mean"] for agent_info in agent_infos[inds]])
            action_dist_logstds_n = np.array(
                [agent_info["log_std"] for agent_info in agent_infos[inds]])
            feed = {
                self.net.obs: obs_n,
                self.net.advant: advant_n,
                self.net.old_dist_means_n: action_dist_means_n,
                self.net.old_dist_logstds_n: action_dist_logstds_n,
                self.net.action_n: action_n
            }

            episoderewards = np.array(
                [path["rewards"].sum() for path in paths])
            thprev = self.gf()  # get theta_old

            def fisher_vector_product(p):
                feed[self.flat_tangent] = p
                return self.session.run(self.fvp, feed) + pms.cg_damping * p

            g = self.session.run(self.pg, feed_dict=feed)
            stepdir = krylov.cg(fisher_vector_product,
                                -g,
                                cg_iters=pms.cg_iters)
            shs = 0.5 * stepdir.dot(fisher_vector_product(stepdir))  # theta
            # if shs<0, then the nan error would appear
            lm = np.sqrt(shs / pms.max_kl)
            fullstep = stepdir / lm
            neggdotstepdir = -g.dot(stepdir)

            def loss(th):
                self.sff(th)
                return self.session.run(self.losses, feed_dict=feed)

            if parallel is True:
                theta = linesearch_parallel(loss, thprev, fullstep,
                                            neggdotstepdir / lm)
            else:
                if linear_search:
                    theta = linesearch(loss, thprev, fullstep,
                                       neggdotstepdir / lm)
                else:
                    theta = thprev + fullstep
            accum_fullstep += (theta - thprev)
        theta = thprev + accum_fullstep * pms.subsample_factor
        stats = {}
        stats["sum steps of episodes"] = sample_data["sum_episode_steps"]
        stats["Average sum of rewards per episode"] = episoderewards.mean()
        stats["surr loss"] = loss(theta)[0]
        stats["Time elapsed"] = "%.2f mins" % (
            (time.time() - start_time) / 60.0)
        self.logger.log_row([
            pms.subsample_factor, stats["Average sum of rewards per episode"],
            self.session.run(self.net.action_dist_logstd_param)[0][0]
        ])
        return stats, theta, thprev

    def process_paths(self, paths):
        sum_episode_steps = 0
        for path in paths:
            sum_episode_steps += path['episode_steps']
            path['baselines'] = self.baseline.predict(path)
            path["returns"] = np.concatenate(
                discount(path["rewards"], pms.discount))
            path["advantages"] = path['returns'] - path['baselines']

        observations = np.concatenate([path["observations"] for path in paths])
        actions = np.concatenate([path["actions"] for path in paths])
        rewards = np.concatenate([path["rewards"] for path in paths])
        advantages = np.concatenate([path["advantages"] for path in paths])
        env_infos = np.concatenate([path["env_infos"] for path in paths])
        agent_infos = np.concatenate([path["agent_infos"] for path in paths])
        if pms.center_adv:
            advantages -= advantages.mean()
            advantages /= (advantages.std() + 1e-8)

        # for some unknown reason, it cannot be used
        # if pms.positive_adv:
        #     advantages = (advantages - np.min(advantages)) + 1e-8

        # average_discounted_return = \
        #     np.mean([path["returns"][0] for path in paths])
        #
        # undiscounted_returns = [sum(path["rewards"]) for path in paths]

        # ev = self.explained_variance_1d(
        #     np.concatenate(baselines),
        #     np.concatenate(returns)
        # )
        samples_data = dict(observations=observations,
                            actions=actions,
                            rewards=rewards,
                            advantages=advantages,
                            env_infos=env_infos,
                            agent_infos=agent_infos,
                            paths=paths,
                            sum_episode_steps=sum_episode_steps)
        self.baseline.fit(paths)
        return samples_data

    def save_model(self, model_name):
        self.saver.save(self.session, "checkpoint/" + model_name + ".ckpt")

    def load_model(self, model_name):
        try:
            if model_name is not None:
                self.saver.restore(self.session, model_name)
            else:
                self.saver.restore(
                    self.session,
                    tf.train.latest_checkpoint(pms.checkpoint_dir))
        except Exception:
            print("load model %s fail" % model_name)
Example #7
class TRPOAgent(object):
    def __init__(self):
        self.env = env = Environment(gym.make(pms.environment_name))
        # if not isinstance(env.observation_space, Box) or \
        #    not isinstance(env.action_space, Discrete):
        #     print("Incompatible spaces.")
        #     exit(-1)
        print("Observation Space", env.observation_space)
        print("Action Space", env.action_space)
        print("Action area, high:%f, low%f" %
              (env.action_space.high, env.action_space.low))
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1 / 3.0)
        self.session = tf.Session(config=tf.ConfigProto(
            gpu_options=gpu_options))
        self.end_count = 0
        self.paths = []
        self.train = True
        self.baseline = Baseline(self.session)
        self.storage = Storage(self, self.env, self.baseline)
        self.distribution = DiagonalGaussian(pms.action_shape)
        self.init_network()
        if pms.train_flag:
            self.init_logger()

    def init_logger(self):
        head = [
            "average_episode_std", "sum steps episode number",
            "total number of episodes", "Average sum of rewards per episode",
            "KL between old and new distribution", "Surrogate loss",
            "Surrogate loss prev", "ds", "entropy", "mean_advant"
        ]
        self.logger = Logger(head)

    def init_network(self):
        self.obs = obs = tf.placeholder(dtype,
                                        shape=[pms.batch_size, pms.obs_shape],
                                        name="obs")
        self.action_n = tf.placeholder(
            dtype, shape=[pms.batch_size, pms.action_shape], name="action")
        self.advant = tf.placeholder(dtype,
                                     shape=[pms.batch_size],
                                     name="advant")
        self.old_dist_means_n = tf.placeholder(
            dtype,
            shape=[pms.batch_size, pms.action_shape],
            name="oldaction_dist_means")
        self.old_dist_logstds_n = tf.placeholder(
            dtype,
            shape=[pms.batch_size, pms.action_shape],
            name="oldaction_dist_logstds")

        # Create mean network.
        # self.fp_mean1, weight_fp_mean1, bias_fp_mean1 = linear(self.obs, 32, activation_fn=tf.nn.tanh, name="fp_mean1")
        # self.fp_mean2, weight_fp_mean2, bias_fp_mean2 = linear(self.fp_mean1, 32, activation_fn=tf.nn.tanh, name="fp_mean2")
        # self.action_dist_means_n, weight_action_dist_means_n, bias_action_dist_means_n = linear(self.fp_mean2, pms.action_shape, name="action_dist_means")
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(3,
                                                 forget_bias=0.0,
                                                 state_is_tuple=True)
        lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell,
                                                  output_keep_prob=0.5)
        rnn = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * 3, state_is_tuple=True)
        # rnn = tf.nn.rnn_cell.BasicRNNCell(3)
        self.initial_state = state = rnn.zero_state(pms.batch_size, tf.float32)
        # output , state = tf.nn.dynamic_rnn(rnn, self.obs)
        output, state = rnn(self.obs, state)

        self.action_dist_means_n = (pt.wrap(output).fully_connected(
            16,
            activation_fn=tf.nn.tanh,
            init=tf.random_normal_initializer(stddev=1.0),
            bias=False).fully_connected(
                16,
                activation_fn=tf.nn.tanh,
                init=tf.random_normal_initializer(stddev=1.0),
                bias=False).fully_connected(
                    pms.action_shape,
                    init=tf.random_normal_initializer(stddev=1.0),
                    bias=False))

        self.N = tf.shape(obs)[0]
        Nf = tf.cast(self.N, dtype)
        # Create std network.
        if pms.use_std_network:
            self.action_dist_logstds_n = (pt.wrap(self.obs).fully_connected(
                16,
                activation_fn=tf.nn.tanh,
                init=tf.random_normal_initializer(stddev=1.0),
                bias=False).fully_connected(
                    16,
                    activation_fn=tf.nn.tanh,
                    init=tf.random_normal_initializer(stddev=1.0),
                    bias=False).fully_connected(
                        pms.action_shape,
                        init=tf.random_normal_initializer(stddev=1.0),
                        bias=False))
        else:
            self.action_dist_logstds_n = tf.placeholder(
                dtype, shape=[pms.batch_size, pms.action_shape], name="logstd")
        log_std_var = self.action_dist_logstds_n
        if pms.min_std is not None:
            # clip the log-std from below so the policy std never collapses
            log_std_var = tf.maximum(log_std_var, np.log(pms.min_std))
        self.action_dist_stds_n = tf.exp(log_std_var)

        self.old_dist_info_vars = dict(mean=self.old_dist_means_n,
                                       log_std=self.old_dist_logstds_n)
        self.new_dist_info_vars = dict(mean=self.action_dist_means_n,
                                       log_std=self.action_dist_logstds_n)
        self.likehood_action_dist = self.distribution.log_likelihood_sym(
            self.action_n, self.new_dist_info_vars)
        self.ratio_n = self.distribution.likelihood_ratio_sym(
            self.action_n, self.new_dist_info_vars, self.old_dist_info_vars)

        surr = -tf.reduce_mean(self.ratio_n * self.advant)  # Surrogate loss
        kl = tf.reduce_mean(
            self.distribution.kl_sym(self.old_dist_info_vars,
                                     self.new_dist_info_vars))
        ent = self.distribution.entropy(self.old_dist_info_vars)
        # ent = tf.reduce_sum(-p_n * tf.log(p_n + eps)) / Nf
        self.losses = [surr, kl, ent]

        var_list = tf.trainable_variables()
        self.gf = GetFlat(self.session, var_list)  # get theta from var_list
        self.sff = SetFromFlat(self.session,
                               var_list)  # set theta from var_List
        # get g
        self.pg = flatgrad(surr, var_list)
        # get A

        # KL divergence where first arg is fixed
        # replace old->tf.stop_gradient from previous kl
        kl_firstfixed = kl_sym_gradient(self.old_dist_means_n,
                                        self.old_dist_logstds_n,
                                        self.action_dist_means_n,
                                        self.action_dist_logstds_n)

        grads = tf.gradients(kl, var_list)
        self.flat_tangent = tf.placeholder(dtype, shape=[None])
        shapes = map(var_shape, var_list)
        start = 0
        tangents = []
        for shape in shapes:
            size = np.prod(shape)
            param = tf.reshape(self.flat_tangent[start:(start + size)], shape)
            tangents.append(param)
            start += size
        self.gvp = [tf.reduce_sum(g * t) for (g, t) in zip(grads, tangents)]
        self.fvp = flatgrad(tf.reduce_sum(self.gvp), var_list)  # get kl''*p

        self.session.run(tf.initialize_all_variables())
        self.saver = tf.train.Saver(max_to_keep=10)
        # self.load_model()

    def get_samples(self, path_number):
        # thread_pool = []
        # for i in range(path_number):
        #     thread_pool.append(Rollout(i , self.storage))
        #
        # for thread in thread_pool:
        #     thread.start()
        #
        # for thread in thread_pool:
        #     thread.join()
        for i in range(pms.paths_number):
            self.storage.get_single_path()

    def get_action(self, obs, *args):
        obs = np.expand_dims(obs, 0)
        temp = np.zeros((pms.batch_size, obs.shape[1]))
        for i in range(pms.batch_size):
            temp[i] = obs[0]
        obs = temp
        # action_dist_logstd = np.expand_dims([np.log(pms.std)], 0)
        if pms.use_std_network:
            action_dist_means_n, action_dist_logstds_n = self.session.run(
                [self.action_dist_means_n, self.action_dist_logstds_n],
                {self.obs: obs})
            if pms.train_flag:
                rnd = np.random.normal(size=action_dist_means_n[0].shape)
                action = rnd * np.exp(
                    action_dist_logstds_n[0]) + action_dist_means_n[0]
            else:
                action = action_dist_means_n[0]
            # action = np.clip(action, pms.min_a, pms.max_a)
            return action, dict(mean=action_dist_means_n[0],
                                log_std=action_dist_logstds_n[0])
        else:
            action_dist_logstd = np.expand_dims([np.log(pms.std)], 0)
            action_dist_means_n = self.session.run(self.action_dist_means_n,
                                                   {self.obs: obs})
            if pms.train_flag:
                rnd = np.random.normal(size=action_dist_means_n[0].shape)
                action = rnd * np.exp(
                    action_dist_logstd[0]) + action_dist_means_n[0]
            else:
                action = action_dist_means_n[0]
            # action = np.clip(action, pms.min_a, pms.max_a)
            return action, dict(mean=action_dist_means_n[0],
                                log_std=action_dist_logstd[0])

    def learn(self):
        start_time = time.time()
        numeptotal = 0
        i = 0
        while True:
            # Generating paths.
            print("Rollout")
            self.get_samples(pms.paths_number)
            paths = self.storage.get_paths()  # get_paths
            # Computing returns and estimating advantage function.
            sample_data = self.storage.process_paths([paths[0]])

            agent_infos = sample_data["agent_infos"]
            obs_n = sample_data["observations"]
            action_n = sample_data["actions"]
            advant_n = sample_data["advantages"]
            n_samples = len(obs_n)
            inds = np.array(range(0, n_samples))
            obs_n = obs_n[inds]
            action_n = action_n[inds]
            advant_n = advant_n[inds]
            action_dist_means_n = np.array(
                [agent_info["mean"] for agent_info in agent_infos[inds]])
            action_dist_logstds_n = np.array(
                [agent_info["log_std"] for agent_info in agent_infos[inds]])
            feed = {
                self.obs: obs_n,
                self.advant: advant_n,
                self.old_dist_means_n: action_dist_means_n,
                self.old_dist_logstds_n: action_dist_logstds_n,
                self.action_dist_logstds_n: action_dist_logstds_n,
                self.action_n: action_n
            }

            episoderewards = np.array(
                [path["rewards"].sum() for path in paths])
            average_episode_std = np.mean(np.exp(action_dist_logstds_n))

            print "\n********** Iteration %i ************" % i
            for iter_num_per_train in range(pms.iter_num_per_train):
                # if not self.train:
                #     print("Episode mean: %f" % episoderewards.mean())
                #     self.end_count += 1
                #     if self.end_count > 100:
                #         break
                if self.train:
                    thprev = self.gf()  # get theta_old

                    def fisher_vector_product(p):
                        feed[self.flat_tangent] = p
                        return self.session.run(self.fvp,
                                                feed) + pms.cg_damping * p

                    g = self.session.run(self.pg, feed_dict=feed)
                    stepdir = krylov.cg(fisher_vector_product,
                                        g,
                                        cg_iters=pms.cg_iters)
                    shs = 0.5 * stepdir.dot(
                        fisher_vector_product(stepdir))  # theta
                    fullstep = stepdir * np.sqrt(2.0 * pms.max_kl / shs)
                    neggdotstepdir = -g.dot(stepdir)

                    def loss(th):
                        self.sff(th)
                        return self.session.run(self.losses, feed_dict=feed)

                    surr_prev, kl_prev, ent_prev = loss(thprev)
                    mean_advant = np.mean(advant_n)
                    theta = linesearch(loss, thprev, fullstep, neggdotstepdir)
                    self.sff(theta)

                    surrafter, kloldnew, entnew = self.session.run(
                        self.losses, feed_dict=feed)

                    stats = {}

                    numeptotal += len(episoderewards)

                    stats["average_episode_std"] = average_episode_std
                    stats["sum steps of episodes"] = sample_data[
                        "sum_episode_steps"]
                    stats["Total number of episodes"] = numeptotal
                    stats[
                        "Average sum of rewards per episode"] = episoderewards.mean(
                        )
                    # stats["Entropy"] = entropy
                    # exp = explained_variance(np.array(baseline_n), np.array(returns_n))
                    # stats["Baseline explained"] = exp
                    stats["Time elapsed"] = "%.2f mins" % (
                        (time.time() - start_time) / 60.0)
                    stats["KL between old and new distribution"] = kloldnew
                    stats["Surrogate loss"] = surrafter
                    stats["Surrogate loss prev"] = surr_prev
                    stats["entropy"] = ent_prev
                    stats["mean_advant"] = mean_advant
                    log_data = [
                        average_episode_std,
                        len(episoderewards), numeptotal,
                        episoderewards.mean(), kloldnew, surrafter, surr_prev,
                        surrafter - surr_prev, ent_prev, mean_advant
                    ]
                    self.logger.log_row(log_data)
                    for k, v in stats.items():
                        print(k + ": " + " " * (40 - len(k)) + str(v))
                        # if entropy != entropy:
                        #     exit(-1)
                        # if exp > 0.95:
                        #     self.train = False
            self.save_model("iter" + str(i))
            i += 1

    def test(self, model_name):
        self.load_model(model_name)
        if pms.record_movie:
            for i in range(100):
                self.storage.get_single_path()
            self.env.env.monitor.close()
            if pms.upload_to_gym:
                gym.upload("log/trpo",
                           algorithm_id='alg_8BgjkAsQRNiWu11xAhS4Hg',
                           api_key='sk_IJhy3b2QkqL3LWzgBXoVA')
        else:
            for i in range(50):
                self.storage.get_single_path()

    def save_model(self, model_name):
        self.saver.save(self.session, "checkpoint/" + model_name + ".ckpt")

    def load_model(self, model_name):
        try:
            self.saver.restore(self.session, model_name)
        except Exception:
            print("load model %s fail" % model_name)
Example #8
class GaussianActor(tf.keras.Model):
    LOG_SIG_CAP_MAX = 2  # np.e**2 = 7.389
    LOG_SIG_CAP_MIN = -20  # np.e**-20 = 2.061e-9

    def __init__(self,
                 state_shape,
                 action_dim,
                 max_action,
                 units=[256, 256],
                 hidden_activation="relu",
                 fix_std=False,
                 const_std=0.1,
                 state_independent_std=False,
                 name='GaussianPolicy'):
        super(GaussianActor, self).__init__(name=name)
        self.dist = DiagonalGaussian(dim=action_dim)
        self._fix_std = fix_std
        self._const_std = const_std
        self._max_action = max_action
        self._state_independent_std = state_independent_std

        self.l1 = Dense(units[0], name="L1", activation=hidden_activation)
        self.l2 = Dense(units[1], name="L2", activation=hidden_activation)
        self.out_mean = Dense(action_dim, name="L_mean")

        if not self._fix_std:
            # check whether the std is state-independent
            if self._state_independent_std:
                self.out_log_std = tf.Variable(
                    initial_value=0.5*np.ones(action_dim, dtype=np.float32),
                    dtype=tf.float32, name="logstd"
                )
            else:
                self.out_log_std = Dense(action_dim, name="L_sigma")

        self(tf.constant(np.zeros(shape=(1,)+state_shape, dtype=np.float32)))

    def _compute_dist(self, states):
        features = self.l1(states)
        features = self.l2(features)
        mean = self.out_mean(features)

        if self._fix_std:
            log_std = tf.ones_like(mean) * tf.math.log(self._const_std)
        else:
            if self._state_independent_std:
                log_std = tf.tile(
                    input=tf.expand_dims(self.out_log_std, axis=0),
                    multiples=[mean.shape[0], 1]
                )
            else:
                log_std = self.out_log_std(features)
                log_std = tf.clip_by_value(log_std, self.LOG_SIG_CAP_MIN, self.LOG_SIG_CAP_MAX)

        return {"mean": mean, "log_std": log_std}

    def call(self, states, test=False):
        param = self._compute_dist(states)
        if test:
            raw_actions = param["mean"]
        else:
            raw_actions = self.dist.sample(param)
        logp_pis = self.dist.log_likelihood(raw_actions, param)
        actions = raw_actions

        return actions * self._max_action, logp_pis, param

    def compute_log_probs(self, states, actions):
        actions /= self._max_action
        param = self._compute_dist(states)
        logp_pis = self.dist.log_likelihood(actions, param)

        return logp_pis

    def compute_entropy(self, states):
        param = self._compute_dist(states)
        return self.dist.entropy(param)
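A minimal usage sketch for the Keras actor above, assuming GaussianActor and its DiagonalGaussian dependency are importable from the surrounding project; shapes and sizes are illustrative:

import numpy as np
import tensorflow as tf

actor = GaussianActor(state_shape=(8,), action_dim=2, max_action=1.0)
states = tf.constant(np.random.randn(32, 8).astype(np.float32))

actions, logp_pis, dist_param = actor(states)       # stochastic actions for training
greedy_actions, _, _ = actor(states, test=True)     # deterministic (mean) actions
logp = actor.compute_log_probs(states, actions)     # log-probs of given actions
entropy = actor.compute_entropy(states)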
Example #9
class TRPOAgentParallel(multiprocessing.Process):
    def __init__(self, observation_space, action_space, task_q, result_q):
        multiprocessing.Process.__init__(self)
        self.task_q = task_q
        self.result_q = result_q
        self.observation_space = observation_space
        self.action_space = action_space
        self.args = pms
        self.baseline = Baseline()
        self.distribution = DiagonalGaussian(pms.action_shape)

    def init_network(self):
        """
            [input]
            self.obs
            self.action_n
            self.advant
            self.old_dist_means_n
            self.old_dist_logstds_n
            [output]
            self.action_dist_means_n
            self.action_dist_logstds_n
            var_list
            """
        self.net = NetworkContinous("network_continous")
        log_std_var = self.net.action_dist_logstds_n
        if pms.min_std is not None:
            # clip the log-std from below so the policy std never collapses
            log_std_var = tf.maximum(log_std_var, np.log(pms.min_std))
        self.action_dist_stds_n = tf.exp(log_std_var)
        self.old_dist_info_vars = dict(mean=self.net.old_dist_means_n,
                                       log_std=self.net.old_dist_logstds_n)
        self.new_dist_info_vars = dict(mean=self.net.action_dist_means_n,
                                       log_std=self.net.action_dist_logstds_n)
        self.likehood_action_dist = self.distribution.log_likelihood_sym(
            self.net.action_n, self.new_dist_info_vars)
        self.ratio_n = self.distribution.likelihood_ratio_sym(
            self.net.action_n, self.new_dist_info_vars,
            self.old_dist_info_vars)
        surr = -tf.reduce_mean(
            self.ratio_n * self.net.advant)  # Surrogate loss
        batch_size = tf.shape(self.net.obs)[0]
        batch_size_float = tf.cast(batch_size, tf.float32)
        kl = tf.reduce_sum(self.distribution.kl_sym(
            self.old_dist_info_vars,
            self.new_dist_info_vars)) / batch_size_float
        ent = tf.reduce_sum(self.net.action_dist_logstds_n + tf.constant(
            0.5 * np.log(2 * np.pi * np.e), tf.float32)) / batch_size_float
        # ent = tf.reduce_sum(-p_n * tf.log(p_n + eps)) / Nf
        self.losses = [surr, kl, ent]
        var_list = self.net.var_list
        config = tf.ConfigProto(device_count={'GPU': 0})
        self.session = tf.Session(config=config)
        self.gf = GetFlat(var_list)  # get theta from var_list
        self.gf.session = self.session
        self.sff = SetFromFlat(var_list)  # set theta from var_List
        self.sff.session = self.session
        # get g
        self.pg = flatgrad(surr, var_list)
        # get A
        # KL divergence where first arg is fixed
        # replace old->tf.stop_gradient from previous kl
        kl_firstfixed = self.distribution.kl_sym_firstfixed(
            self.new_dist_info_vars) / batch_size_float
        grads = tf.gradients(kl_firstfixed, var_list)
        self.flat_tangent = tf.placeholder(dtype, shape=[None])
        shapes = map(var_shape, var_list)
        start = 0
        tangents = []
        for shape in shapes:
            size = np.prod(shape)
            param = tf.reshape(self.flat_tangent[start:(start + size)], shape)
            tangents.append(param)
            start += size
        self.gvp = [tf.reduce_sum(g * t) for (g, t) in zip(grads, tangents)]
        self.fvp = flatgrad(tf.reduce_sum(self.gvp), var_list)  # get kl''*p
        self.session.run(tf.initialize_all_variables())
        self.saver = tf.train.Saver(max_to_keep=5)

    def run(self):
        self.init_network()
        while True:
            paths = self.task_q.get()
            if paths is None:
                # kill the learner
                self.task_q.task_done()
                break
            elif paths == 1:
                # just get params, no learn
                self.task_q.task_done()
                self.result_q.put(self.gf())
            elif paths[0] == 2:
                # adjusting the max KL.
                self.args.max_kl = paths[1]
                if paths[2] == 1:
                    print "saving checkpoint..."
                    self.save_model(pms.environment_name + "-" + str(paths[3]))
                self.task_q.task_done()
            else:
                stats, theta, thprev = self.learn(paths)
                self.sff(theta)
                self.task_q.task_done()
                self.result_q.put((stats, theta, thprev))
        return

    def learn(self, paths, parallel=False, linear_search=False):
        # Generating paths.
        start_time = time.time()
        # Computing returns and estimating advantage function.
        sample_data = self.process_paths(paths)
        agent_infos = sample_data["agent_infos"]
        obs_n = sample_data["observations"]
        action_n = sample_data["actions"]
        advant_n = sample_data["advantages"]
        n_samples = len(obs_n)
        inds = np.random.choice(
            n_samples,
            int(math.floor(n_samples * pms.subsample_factor)),
            replace=False)
        # inds = range(n_samples)
        obs_n = obs_n[inds]
        action_n = action_n[inds]
        advant_n = advant_n[inds]
        action_dist_means_n = np.array(
            [agent_info["mean"] for agent_info in agent_infos[inds]])
        action_dist_logstds_n = np.array(
            [agent_info["log_std"] for agent_info in agent_infos[inds]])
        feed = {
            self.net.obs: obs_n,
            self.net.advant: advant_n,
            self.net.old_dist_means_n: action_dist_means_n,
            self.net.old_dist_logstds_n: action_dist_logstds_n,
            self.net.action_n: action_n
        }

        episoderewards = np.array([path["rewards"].sum() for path in paths])
        thprev = self.gf()  # get theta_old

        def fisher_vector_product(p):
            feed[self.flat_tangent] = p
            return self.session.run(self.fvp, feed) + pms.cg_damping * p

        g = self.session.run(self.pg, feed_dict=feed)
        stepdir = krylov.cg(fisher_vector_product, -g, cg_iters=pms.cg_iters)
        shs = 0.5 * stepdir.dot(fisher_vector_product(stepdir))  # theta
        # if shs<0, then the nan error would appear
        lm = np.sqrt(shs / pms.max_kl)
        fullstep = stepdir / lm
        neggdotstepdir = -g.dot(stepdir)

        def loss(th):
            self.sff(th)
            return self.session.run(self.losses, feed_dict=feed)

        if parallel is True:
            theta = linesearch_parallel(loss, thprev, fullstep,
                                        neggdotstepdir / lm)
        else:
            if linear_search:
                theta = linesearch(loss, thprev, fullstep, neggdotstepdir / lm)
            else:
                theta = thprev + fullstep
                if math.isnan(theta.mean()):
                    print("theta contains NaN, probably because shs < 0; keeping the previous theta")
                    theta = thprev
        stats = {}
        stats["sum steps of episodes"] = sample_data["sum_episode_steps"]
        stats["Average sum of rewards per episode"] = episoderewards.mean()
        stats["Time elapsed"] = "%.2f mins" % (
            (time.time() - start_time) / 60.0)
        return stats, theta, thprev

    def process_paths(self, paths):
        sum_episode_steps = 0
        for path in paths:
            sum_episode_steps += path['episode_steps']
            # r_t+V(S_{t+1})-V(S_t) = returns-baseline
            # path_baselines = np.append(self.baseline.predict(path) , 0)
            # # r_t+V(S_{t+1})-V(S_t) = returns-baseline
            # path["advantages"] = np.concatenate(path["rewards"]) + \
            #          pms.discount * path_baselines[1:] - \
            #          path_baselines[:-1]
            # path["returns"] = np.concatenate(discount(path["rewards"], pms.discount))
            path_baselines = np.append(self.baseline.predict(path), 0)
            deltas = np.concatenate(path["rewards"]) + \
                     pms.discount * path_baselines[1:] - \
                     path_baselines[:-1]
            path["advantages"] = discount(deltas,
                                          pms.discount * pms.gae_lambda)
            path["returns"] = np.concatenate(
                discount(path["rewards"], pms.discount))
            path["advantages"] = path["returns"]

        observations = np.concatenate([path["observations"] for path in paths])
        actions = np.concatenate([path["actions"] for path in paths])
        rewards = np.concatenate([path["rewards"] for path in paths])
        advantages = np.concatenate([path["advantages"] for path in paths])
        env_infos = np.concatenate([path["env_infos"] for path in paths])
        agent_infos = np.concatenate([path["agent_infos"] for path in paths])
        if pms.center_adv:
            advantages -= np.mean(advantages)
            advantages /= (advantages.std() + 1e-8)
        samples_data = dict(observations=observations,
                            actions=actions,
                            rewards=rewards,
                            advantages=advantages,
                            env_infos=env_infos,
                            agent_infos=agent_infos,
                            paths=paths,
                            sum_episode_steps=sum_episode_steps)
        self.baseline.fit(paths)
        return samples_data

# class TRPOAgentParallel(multiprocessing.Process):
#     def __init__(self , observation_space , action_space , task_q , result_q):
#         multiprocessing.Process.__init__(self)
#         self.task_q = task_q
#         self.result_q = result_q
#         self.observation_space = observation_space
#         self.action_space = action_space
#         self.args = pms
#
#     def run(self):
#         env = Environment(gym.make(pms.environment_name))
#         self.agent = TRPOAgent(env)
#         # self.agent.init_network()
#         while True:
#             paths = self.task_q.get()
#             if paths is None:
#                 # kill the learner
#                 self.task_q.task_done()
#                 break
#             elif paths == 1:
#                 # just get params, no learn
#                 self.task_q.task_done()
#                 self.result_q.put(self.agent.gf())
#             elif paths[0] == 2:
#                 # adjusting the max KL.
#                 self.args.max_kl = paths[1]
#                 self.task_q.task_done()
#             else:
#                 stats, theta, thprev = self.agent.train_paths(paths, parallel=False, linear_search=True)
#                 self.agent.sff(theta)
#                 self.task_q.task_done()
#                 self.result_q.put((stats, theta, thprev))
#         return

    def save_model(self, model_name):
        self.saver.save(self.session, "checkpoint/" + model_name + ".ckpt")

    def load_model(self, model_name):
        try:
            if model_name is not None:
                self.saver.restore(self.session, model_name)
            else:
                self.saver.restore(
                    self.session,
                    tf.train.latest_checkpoint(pms.checkpoint_dir))
        except Exception:
            print("load model %s fail" % model_name)
    def __init__(self,
                 input_shape,
                 output_size,
                 hidden_sizes=(32, 32),
                 learn_std=True,
                 init_std=1.0,
                 adaptive_std=False,
                 std_hidden_sizes=(32, 32),
                 min_std=1e-6,
                 std_parametrization='exp',
                 hidden_nonlinearity=tf.nn.tanh,
                 output_nonlinearity=None,
                 std_hidden_nonlinearity=tf.nn.tanh):

        self.input_shape = input_shape
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.std_parametrization = std_parametrization
        self.locals = locals()

        self.distribution = DiagonalGaussian(output_size)
        self.params = []

        with tf.variable_scope("policy"):
            # Mean network
            self.mean_mlp = MLP(input_shape=input_shape,
                                output_size=output_size,
                                hidden_sizes=hidden_sizes,
                                hidden_nonlinearity=hidden_nonlinearity,
                                output_nonlinearity=output_nonlinearity,
                                name='mean')

            self.x = self.mean_mlp.get_input_layer()
            self.mean = self.mean_mlp.get_output_layer()
            self.params += self.mean_mlp.get_params()

            # Var network
            if adaptive_std:
                self.var_mlp = MLP(input_shape=input_shape,
                                   output_size=output_size,
                                   hidden_sizes=std_hidden_sizes,
                                   hidden_nonlinearity=std_hidden_nonlinearity,
                                   output_nonlinearity=None,
                                   input_layer=self.x,
                                   name='var')
                self.log_var = self.var_mlp.get_output_layer()
                self.params += self.var_mlp.get_params()
            else:
                if std_parametrization == 'exp':
                    init_std_param = numpy.log(init_std)
                elif std_parametrization == 'softplus':
                    init_std_param = numpy.log(numpy.exp(init_std) - 1)
                else:
                    raise NotImplementedError

                with tf.variable_scope('var'):
                    self.log_var = tf.get_variable(
                        name='v',
                        shape=(output_size, ),
                        dtype=tf.float32,
                        initializer=tf.constant_initializer(init_std_param,
                                                            dtype=tf.float32),
                        trainable=learn_std)
                    self.log_var = tf.tile(tf.reshape(self.log_var,
                                                      shape=(-1, output_size)),
                                           multiples=[tf.shape(self.x)[0], 1])

                    self.params += tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES,
                        tf.get_variable_scope().name)

            if std_parametrization == 'exp':
                min_std_param = numpy.log(min_std)
            elif std_parametrization == 'softplus':
                min_std_param = numpy.log(numpy.exp(min_std) - 1)
            else:
                raise NotImplementedError

            self.log_var = tf.maximum(self.log_var, min_std_param)
            if self.std_parametrization == 'softplus':
                self.log_var = tf.log(tf.log(1. + tf.exp(self.log_var)))
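For reference, the two std_parametrization branches above store different raw parameters but are meant to recover the same std; a small NumPy sketch of the mapping (std_from_param is an illustrative helper, not part of this code):

import numpy as np

def std_from_param(param, parametrization='exp'):
    # 'exp': the stored parameter is log(std); 'softplus': std = log(1 + e**param)
    if parametrization == 'exp':
        return np.exp(param)
    elif parametrization == 'softplus':
        return np.log1p(np.exp(param))
    raise NotImplementedError(parametrization)

# both parametrizations recover init_std at initialization
init_std = 1.0
assert np.isclose(std_from_param(np.log(init_std), 'exp'), init_std)
assert np.isclose(std_from_param(np.log(np.exp(init_std) - 1), 'softplus'), init_std)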