# Example no. 1
# 0
class TRPOAgentBase(object):
    """Shared TRPO machinery for agents built around a ``self.net`` policy.

    This class never assigns the following attributes itself; a subclass (or
    its ``init_network``) has to provide them before training or acting:
    ``net``, ``action_dist_stds_n``, ``gf``, ``sff``, ``pg``, ``fvp``,
    ``flat_tangent``, ``losses`` and ``saver``.
    """

    def __init__(self, env):
        """Create the TF session, the rollout storage and the distribution.

        Args:
            env: environment exposing ``observation_space`` and
                ``action_space`` (the latter with ``high`` and ``low``).
        """
        self.env = env
        print("Observation Space", env.observation_space)
        print("Action Space", env.action_space)
        print("Action area, high:%f, low%f" % (env.action_space.high, env.action_space.low))
        # Restrict this process to a small fraction (0.1 / 3) of GPU memory.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1 / 3.0)
        self.session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.end_count = 0
        self.paths = []
        self.train = True
        self.baseline = Baseline()
        self.storage = Storage(self, self.env, self.baseline)
        self.distribution = DiagonalGaussian(pms.action_shape)
        # Subclasses replace this with the policy network.
        self.net = None

    def init_network(self):
        """Build the policy network; must be implemented by subclasses."""
        raise NotImplementedError

    def get_samples(self, path_number):
        """Collect ``path_number`` rollouts into ``self.storage``."""
        for _ in range(path_number):
            self.storage.get_single_path()

    def get_action(self, obs, *args):
        """Pick an action for a single observation.

        Args:
            obs: one observation; a leading batch axis of size 1 is added.
            *args: ignored.

        Returns:
            ``(action, info)`` where ``info`` has the keys ``mean`` and
            ``log_std`` of the distribution the action was drawn from. The
            action is the mean itself when ``pms.train_flag`` is falsy.

        Raises:
            NameError: if ``self.net`` has not been set.
        """
        if self.net is None:
            raise NameError("network have not been defined")
        obs = np.expand_dims(obs, 0)
        # NOTE(review): the std tensor is read from ``self`` while the mean
        # tensor is read from ``self.net`` -- confirm the subclass defines it.
        action_dist_means_n, action_dist_stds_n = self.session.run(
            [self.net.action_dist_means_n, self.action_dist_stds_n],
            {self.net.obs: obs})
        mean = action_dist_means_n[0]
        std = action_dist_stds_n[0]
        if pms.train_flag:
            rnd = np.random.normal(size=mean.shape)
            action = rnd * std + mean
        else:
            action = mean
        # ``std`` is used above as a standard deviation, while train_paths
        # feeds info["log_std"] as the old log-std, so store the logarithm.
        return action, dict(mean=mean, log_std=np.log(std))

    def train_mini_batch(self, parallel=False, linear_search=True):
        """Sample ``pms.paths_number`` rollouts and run one update on them.

        Returns:
            Whatever :meth:`train_paths` returns.
        """
        self.get_samples(pms.paths_number)
        paths = self.storage.get_paths()
        return self.train_paths(paths, parallel=parallel, linear_search=linear_search)

    def train_paths(self, paths, parallel=False, linear_search=True):
        """Compute one TRPO parameter update from ``paths``.

        The new parameters are returned, not necessarily left applied: the
        ``loss`` helper handed to the line search overwrites the network
        parameters every time it is evaluated.

        Args:
            paths: rollouts accepted by ``self.storage.process_paths``; each
                one must have a ``"rewards"`` array.
            parallel: when exactly ``True`` use ``linesearch_parallel``.
            linear_search: when ``parallel`` is not ``True``, choose between
                ``linesearch`` and taking the full step directly.

        Returns:
            ``(stats, theta, thprev)``: a dict of summary values, the proposed
            flat parameter vector and the previous flat parameter vector.
        """
        start_time = time.time()
        sample_data = self.storage.process_paths(paths)
        agent_infos = sample_data["agent_infos"]
        obs_n = sample_data["observations"]
        action_n = sample_data["actions"]
        advant_n = sample_data["advantages"]

        # Work on a random subset of the samples (without replacement).
        n_samples = len(obs_n)
        inds = np.random.choice(n_samples, int(math.floor(n_samples * pms.subsample_factor)), replace=False)
        obs_n = obs_n[inds]
        action_n = action_n[inds]
        advant_n = advant_n[inds]
        action_dist_means_n = np.array([agent_info["mean"] for agent_info in agent_infos[inds]])
        action_dist_logstds_n = np.array([agent_info["log_std"] for agent_info in agent_infos[inds]])
        feed = {
            self.net.obs: obs_n,
            self.net.advant: advant_n,
            self.net.old_dist_means_n: action_dist_means_n,
            self.net.old_dist_logstds_n: action_dist_logstds_n,
            self.net.action_n: action_n,
        }

        episoderewards = np.array([path["rewards"].sum() for path in paths])
        thprev = self.gf()  # current (old) flat parameters

        def fisher_vector_product(p):
            # Damped product of the KL curvature with ``p``.
            feed[self.flat_tangent] = p
            return self.session.run(self.fvp, feed) + pms.cg_damping * p

        g = self.session.run(self.pg, feed_dict=feed)
        stepdir = krylov.cg(fisher_vector_product, -g, cg_iters=pms.cg_iters)
        shs = 0.5 * stepdir.dot(fisher_vector_product(stepdir))
        # A negative ``shs`` makes the square root (and the step) NaN.
        lm = np.sqrt(shs / pms.max_kl)
        fullstep = stepdir / lm
        neggdotstepdir = -g.dot(stepdir)

        def loss(th):
            # Side effect: writes ``th`` into the network before evaluating.
            self.sff(th)
            return self.session.run(self.losses, feed_dict=feed)

        if parallel is True:
            theta = linesearch_parallel(loss, thprev, fullstep, neggdotstepdir / lm)
        elif linear_search:
            theta = linesearch(loss, thprev, fullstep, neggdotstepdir / lm)
        else:
            theta = thprev + fullstep
            if math.isnan(theta.mean()):
                # Full step is unusable; report and keep the old parameters.
                print(shs is None)
                theta = thprev

        stats = {}
        stats["sum steps of episodes"] = sample_data["sum_episode_steps"]
        stats["Average sum of rewards per episode"] = episoderewards.mean()
        stats["Time elapsed"] = "%.2f mins" % ((time.time() - start_time) / 60.0)
        return stats, theta, thprev

    def learn(self):
        """Run the training loop; must be implemented by subclasses."""
        raise NotImplementedError

    def test(self, model_name):
        """Load ``model_name`` and run evaluation rollouts.

        Runs 100 rollouts, closes the monitor and optionally uploads the
        results when ``pms.record_movie`` is set; otherwise runs 50 rollouts.
        """
        self.load_model(model_name)
        if pms.record_movie:
            for _ in range(100):
                self.storage.get_single_path()
            self.env.env.monitor.close()
            if pms.upload_to_gym:
                # NOTE(review): the API key is hard-coded in source; consider
                # moving it to configuration.
                gym.upload("log/trpo", algorithm_id='alg_8BgjkAsQRNiWu11xAhS4Hg', api_key='sk_IJhy3b2QkqL3LWzgBXoVA')
        else:
            for _ in range(50):
                self.storage.get_single_path()

    def save_model(self, model_name):
        """Save the session to ``checkpoint/<model_name>.ckpt``."""
        self.saver.save(self.session, "checkpoint/" + model_name + ".ckpt")

    def load_model(self, model_name):
        """Restore the session from a checkpoint (best effort).

        Args:
            model_name: checkpoint path, or ``None`` to restore the latest
                checkpoint found in ``pms.checkpoint_dir``.

        Failures are reported on stdout and otherwise ignored.
        """
        try:
            if model_name is not None:
                self.saver.restore(self.session, model_name)
            else:
                self.saver.restore(self.session, tf.train.latest_checkpoint(pms.checkpoint_dir))
        except Exception as err:
            # Keep loading best-effort, but do not swallow KeyboardInterrupt
            # or SystemExit, and say why the restore failed.
            print("load model %s fail: %s" % (model_name, err))
class TRPOAgentContinousSingleProcess(object):
    """TRPO worker for continuous actions that runs inside a master's session.

    The worker reads ``session``, ``logger`` and ``save_model`` from
    ``self.master``.
    """

    def __init__(self, thread_id, master=None):
        """Create the worker's environment, storage and network.

        Args:
            thread_id: integer id of this worker; also used to name its
                network.
            master: optional owner object providing ``session``, ``logger``
                and ``save_model``. When omitted, ``self.master`` must already
                be resolvable (for example as a class attribute), which is
                what the original code relied on.
        """
        print("create worker %d" % (thread_id))
        self.thread_id = thread_id
        if master is not None:
            self.master = master
        self.env = Environment(gym.make(pms.environment_name))
        self.end_count = 0
        self.paths = []
        self.train = True
        self.baseline = Baseline()
        self.storage = Storage(self, self.env, self.baseline)
        self.distribution = DiagonalGaussian(pms.action_shape)

        self.session = self.master.session
        self.init_network()

    def init_network(self):
        """Build the network and the symbolic quantities used by TRPO.

        Sets ``network``, ``action_dist_stds_n``, the old/new distribution
        dicts, ``losses`` (surrogate, KL, entropy), ``gf``/``sff`` (flat
        parameter getter/setter), ``pg`` (flat surrogate gradient),
        ``flat_tangent`` and ``fvp`` (flat gradient of grad(KL) . tangent).
        """
        self.network = NetworkContinous(str(self.thread_id))
        # Optionally clamp the log-std from below. Without a minimum the raw
        # log-std is used (previously this case hit an undefined name).
        log_std_var = self.network.action_dist_logstds_n
        if pms.min_std is not None:
            log_std_var = tf.maximum(log_std_var, np.log(pms.min_std))
        self.action_dist_stds_n = tf.exp(log_std_var)

        self.old_dist_info_vars = dict(mean=self.network.old_dist_means_n,
                                       log_std=self.network.old_dist_logstds_n)
        self.new_dist_info_vars = dict(
            mean=self.network.action_dist_means_n,
            log_std=self.network.action_dist_logstds_n)
        self.likehood_action_dist = self.distribution.log_likelihood_sym(
            self.network.action_n, self.new_dist_info_vars)
        self.ratio_n = self.distribution.likelihood_ratio_sym(
            self.network.action_n, self.new_dist_info_vars,
            self.old_dist_info_vars)

        # Surrogate loss: negated mean of likelihood ratio times advantage.
        surr = -tf.reduce_mean(self.ratio_n * self.network.advant)
        kl = tf.reduce_mean(
            self.distribution.kl_sym(self.old_dist_info_vars,
                                     self.new_dist_info_vars))
        ent = self.distribution.entropy(self.old_dist_info_vars)
        self.losses = [surr, kl, ent]

        var_list = self.network.var_list
        self.gf = GetFlat(self.session, var_list)  # read flat parameters
        self.sff = SetFromFlat(self.session, var_list)  # write flat parameters
        self.pg = flatgrad(surr, var_list)

        # Built but not used below: the curvature product differentiates
        # ``kl`` directly.
        kl_firstfixed = kl_sym_gradient(self.network.old_dist_means_n,
                                        self.network.old_dist_logstds_n,
                                        self.network.action_dist_means_n,
                                        self.network.action_dist_logstds_n)

        grads = tf.gradients(kl, var_list)
        self.flat_tangent = tf.placeholder(dtype, shape=[None])
        # Cut the flat tangent into one piece per variable, in var_list order.
        tangents = []
        start = 0
        for shape in [var_shape(var) for var in var_list]:
            size = np.prod(shape)
            tangents.append(
                tf.reshape(self.flat_tangent[start:(start + size)], shape))
            start += size
        self.gvp = [tf.reduce_sum(g * t) for (g, t) in zip(grads, tangents)]
        self.fvp = flatgrad(tf.reduce_sum(self.gvp), var_list)

    def get_samples(self, path_number):
        """Collect ``path_number`` rollouts into ``self.storage``."""
        for _ in range(path_number):
            self.storage.get_single_path()

    def get_action(self, obs, *args):
        """Pick an action for a single observation.

        With ``pms.use_std_network`` the log-std comes from the network,
        otherwise it is the constant ``log(pms.std)``.

        Args:
            obs: one observation; a leading batch axis of size 1 is added.
            *args: ignored.

        Returns:
            ``(action, info)`` where ``info`` has the keys ``mean`` and
            ``log_std``. The action is the mean itself when
            ``pms.train_flag`` is falsy.
        """
        obs = np.expand_dims(obs, 0)
        if pms.use_std_network:
            # The tensors live on ``self.network`` in this class.
            means_n, logstds_n = self.session.run(
                [self.network.action_dist_means_n,
                 self.network.action_dist_logstds_n],
                {self.network.obs: obs})
            mean = means_n[0]
            log_std = logstds_n[0]
        else:
            log_std = np.expand_dims([np.log(pms.std)], 0)[0]
            mean = self.network.get_action_dist_means_n(self.session, obs)[0]

        if pms.train_flag:
            rnd = np.random.normal(size=mean.shape)
            action = rnd * np.exp(log_std) + mean
        else:
            action = mean
        return action, dict(mean=mean, log_std=log_std)

    def run(self):
        """Entry point of the worker; delegates to :meth:`learn`."""
        self.learn()

    def learn(self):
        """Sample rollouts and apply TRPO updates forever.

        Every outer iteration collects ``pms.paths_number`` rollouts and runs
        ``pms.iter_num_per_train`` updates on a random subsample, logging one
        row per update through ``self.master.logger``. The worker whose
        ``thread_id`` is 1 saves a checkpoint after each outer iteration.
        """
        start_time = time.time()  # kept for parity; not reported by this worker
        numeptotal = 0
        # Counts outer iterations; must live outside the loop so checkpoint
        # names advance instead of always being "iter0".
        i = 0
        while True:
            self.get_samples(pms.paths_number)
            paths = self.storage.get_paths()
            sample_data = self.storage.process_paths(paths)

            agent_infos = sample_data["agent_infos"]
            obs_n = sample_data["observations"]
            action_n = sample_data["actions"]
            advant_n = sample_data["advantages"]
            n_samples = len(obs_n)
            # The sample count has to be an int (math.floor gives a float on
            # Python 2).
            inds = np.random.choice(
                n_samples,
                int(math.floor(n_samples * pms.subsample_factor)),
                replace=False)
            obs_n = obs_n[inds]
            action_n = action_n[inds]
            advant_n = advant_n[inds]
            action_dist_means_n = np.array(
                [agent_info["mean"] for agent_info in agent_infos[inds]])
            action_dist_logstds_n = np.array(
                [agent_info["log_std"] for agent_info in agent_infos[inds]])
            feed = {
                self.network.obs: obs_n,
                self.network.advant: advant_n,
                self.network.old_dist_means_n: action_dist_means_n,
                self.network.old_dist_logstds_n: action_dist_logstds_n,
                # The current log-std tensor is fed with the sampled values.
                self.network.action_dist_logstds_n: action_dist_logstds_n,
                self.network.action_n: action_n,
            }

            episoderewards = np.array(
                [path["rewards"].sum() for path in paths])
            average_episode_std = np.mean(np.exp(action_dist_logstds_n))

            def fisher_vector_product(p):
                # Damped product of the KL curvature with ``p``.
                feed[self.flat_tangent] = p
                return self.session.run(self.fvp, feed) + pms.cg_damping * p

            def loss(th):
                # Side effect: writes ``th`` into the network first.
                self.sff(th)
                return self.session.run(self.losses, feed_dict=feed)

            for _ in range(pms.iter_num_per_train):
                if not self.train:
                    continue
                thprev = self.gf()  # current (old) flat parameters
                g = self.session.run(self.pg, feed_dict=feed)
                # Solve for the descent direction of the surrogate loss
                # (right-hand side -g, as TRPOAgentBase.train_paths does);
                # solving for +g points uphill and makes neggdotstepdir
                # negative.
                stepdir = krylov.cg(fisher_vector_product,
                                    -g,
                                    cg_iters=pms.cg_iters)
                shs = 0.5 * stepdir.dot(fisher_vector_product(stepdir))
                fullstep = stepdir * np.sqrt(2.0 * pms.max_kl / shs)
                neggdotstepdir = -g.dot(stepdir)

                surr_prev, _, ent_prev = loss(thprev)
                mean_advant = np.mean(advant_n)
                theta = linesearch(loss, thprev, fullstep, neggdotstepdir)
                self.sff(theta)
                surrafter, kloldnew, _ = self.session.run(
                    self.losses, feed_dict=feed)

                numeptotal += len(episoderewards)
                log_data = [
                    average_episode_std,
                    len(episoderewards), numeptotal,
                    episoderewards.mean(), kloldnew, surrafter, surr_prev,
                    surrafter - surr_prev, ent_prev, mean_advant
                ]
                self.master.logger.log_row(log_data)

            if self.thread_id == 1:
                self.master.save_model("iter" + str(i))
                print(episoderewards.mean())
            i += 1

    def test(self, model_name):
        """Load ``model_name`` and run 50 evaluation rollouts."""
        self.load_model(model_name)
        for _ in range(50):
            self.storage.get_single_path()

    def save_model(self, model_name):
        """Save the session to ``checkpoint/<model_name>.ckpt``.

        NOTE(review): ``self.saver`` is not assigned anywhere in this class.
        """
        self.saver.save(self.session, "checkpoint/" + model_name + ".ckpt")

    def load_model(self, model_name):
        """Restore the session from ``model_name`` (best effort).

        Failures are reported on stdout and otherwise ignored.
        """
        try:
            self.saver.restore(self.session, model_name)
        except Exception as err:
            # Keep loading best-effort, but do not swallow KeyboardInterrupt
            # or SystemExit, and say why the restore failed.
            print("load model %s fail: %s" % (model_name, err))
# Example no. 3
# 0
class TRPOAgent(object):
    """Self-contained TRPO agent with a recurrent (LSTM) mean network.

    The agent owns its environment, TensorFlow session, network, logger and
    saver. All placeholders use the fixed batch size ``pms.batch_size``.
    """

    def __init__(self):
        """Create the environment, session, storage, network and logger."""
        self.env = env = Environment(gym.make(pms.environment_name))
        print("Observation Space", env.observation_space)
        print("Action Space", env.action_space)
        print("Action area, high:%f, low%f" %
              (env.action_space.high, env.action_space.low))
        # Restrict this process to a small fraction (0.1 / 3) of GPU memory.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1 / 3.0)
        self.session = tf.Session(config=tf.ConfigProto(
            gpu_options=gpu_options))
        self.end_count = 0
        self.paths = []
        self.train = True
        self.baseline = Baseline(self.session)
        self.storage = Storage(self, self.env, self.baseline)
        self.distribution = DiagonalGaussian(pms.action_shape)
        self.init_network()
        if pms.train_flag:
            self.init_logger()

    def init_logger(self):
        """Create ``self.logger`` with one column per value logged in learn."""
        # Ten columns, matching the ten entries of ``log_data`` in learn().
        # A missing comma used to merge the second and third names into one.
        head = [
            "average_episode_std", "sum steps episode number",
            "total number of episodes", "Average sum of rewards per episode",
            "KL between old and new distribution", "Surrogate loss",
            "Surrogate loss prev", "ds", "entropy", "mean_advant"
        ]
        self.logger = Logger(head)

    def init_network(self):
        """Build placeholders, the policy network and the TRPO graph pieces.

        Sets the input placeholders, ``action_dist_means_n``,
        ``action_dist_logstds_n``, ``action_dist_stds_n``, ``losses``
        (surrogate, KL, entropy), ``gf``/``sff`` (flat parameter
        getter/setter), ``pg``, ``flat_tangent``, ``fvp`` and ``saver``, and
        initializes all variables.
        """
        self.obs = obs = tf.placeholder(dtype,
                                        shape=[pms.batch_size, pms.obs_shape],
                                        name="obs")
        self.action_n = tf.placeholder(
            dtype, shape=[pms.batch_size, pms.action_shape], name="action")
        self.advant = tf.placeholder(dtype,
                                     shape=[pms.batch_size],
                                     name="advant")
        self.old_dist_means_n = tf.placeholder(
            dtype,
            shape=[pms.batch_size, pms.action_shape],
            name="oldaction_dist_means")
        self.old_dist_logstds_n = tf.placeholder(
            dtype,
            shape=[pms.batch_size, pms.action_shape],
            name="oldaction_dist_logstds")

        # Mean network: three stacked LSTM cells (3 units each, output
        # dropout 0.5) applied for a single step from a zero state, followed
        # by fully connected layers.
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(3,
                                                 forget_bias=0.0,
                                                 state_is_tuple=True)
        lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell,
                                                  output_keep_prob=0.5)
        rnn = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * 3, state_is_tuple=True)
        self.initial_state = state = rnn.zero_state(pms.batch_size, tf.float32)
        output, state = rnn(self.obs, state)

        print("")
        self.action_dist_means_n = (pt.wrap(output).fully_connected(
            16,
            activation_fn=tf.nn.tanh,
            init=tf.random_normal_initializer(stddev=1.0),
            bias=False).fully_connected(
                16,
                activation_fn=tf.nn.tanh,
                init=tf.random_normal_initializer(stddev=1.0),
                bias=False).fully_connected(
                    pms.action_shape,
                    init=tf.random_normal_initializer(stddev=1.0),
                    bias=False))

        self.N = tf.shape(obs)[0]
        Nf = tf.cast(self.N, dtype)  # not used below
        # Log-std: either a separate network on the observations or a
        # placeholder that the caller has to feed.
        if pms.use_std_network:
            self.action_dist_logstds_n = (pt.wrap(self.obs).fully_connected(
                16,
                activation_fn=tf.nn.tanh,
                init=tf.random_normal_initializer(stddev=1.0),
                bias=False).fully_connected(
                    16,
                    activation_fn=tf.nn.tanh,
                    init=tf.random_normal_initializer(stddev=1.0),
                    bias=False).fully_connected(
                        pms.action_shape,
                        init=tf.random_normal_initializer(stddev=1.0),
                        bias=False))
        else:
            self.action_dist_logstds_n = tf.placeholder(
                dtype, shape=[pms.batch_size, pms.action_shape], name="logstd")
        # Optionally clamp the log-std from below. Without a minimum the raw
        # log-std is used (previously this case hit an undefined name).
        log_std_var = self.action_dist_logstds_n
        if pms.min_std is not None:
            log_std_var = tf.maximum(log_std_var, np.log(pms.min_std))
        self.action_dist_stds_n = tf.exp(log_std_var)

        self.old_dist_info_vars = dict(mean=self.old_dist_means_n,
                                       log_std=self.old_dist_logstds_n)
        self.new_dist_info_vars = dict(mean=self.action_dist_means_n,
                                       log_std=self.action_dist_logstds_n)
        self.likehood_action_dist = self.distribution.log_likelihood_sym(
            self.action_n, self.new_dist_info_vars)
        self.ratio_n = self.distribution.likelihood_ratio_sym(
            self.action_n, self.new_dist_info_vars, self.old_dist_info_vars)

        # Surrogate loss: negated mean of likelihood ratio times advantage.
        surr = -tf.reduce_mean(self.ratio_n * self.advant)
        kl = tf.reduce_mean(
            self.distribution.kl_sym(self.old_dist_info_vars,
                                     self.new_dist_info_vars))
        ent = self.distribution.entropy(self.old_dist_info_vars)
        self.losses = [surr, kl, ent]

        var_list = tf.trainable_variables()
        self.gf = GetFlat(self.session, var_list)  # read flat parameters
        self.sff = SetFromFlat(self.session, var_list)  # write flat parameters
        self.pg = flatgrad(surr, var_list)

        # Built but not used below: the curvature product differentiates
        # ``kl`` directly.
        kl_firstfixed = kl_sym_gradient(self.old_dist_means_n,
                                        self.old_dist_logstds_n,
                                        self.action_dist_means_n,
                                        self.action_dist_logstds_n)

        grads = tf.gradients(kl, var_list)
        self.flat_tangent = tf.placeholder(dtype, shape=[None])
        # Cut the flat tangent into one piece per variable, in var_list order.
        tangents = []
        start = 0
        for shape in [var_shape(var) for var in var_list]:
            size = np.prod(shape)
            tangents.append(
                tf.reshape(self.flat_tangent[start:(start + size)], shape))
            start += size
        self.gvp = [tf.reduce_sum(g * t) for (g, t) in zip(grads, tangents)]
        self.fvp = flatgrad(tf.reduce_sum(self.gvp), var_list)

        self.session.run(tf.initialize_all_variables())
        self.saver = tf.train.Saver(max_to_keep=10)

    def get_samples(self, path_number):
        """Collect ``path_number`` rollouts into ``self.storage``."""
        for _ in range(path_number):
            self.storage.get_single_path()

    def get_action(self, obs, *args):
        """Pick an action for a single observation.

        The observation is repeated ``pms.batch_size`` times because the
        placeholders have a fixed batch size; only row 0 of the result is
        used.

        Args:
            obs: one observation.
            *args: ignored.

        Returns:
            ``(action, info)`` where ``info`` has the keys ``mean`` and
            ``log_std``. The action is the mean itself when
            ``pms.train_flag`` is falsy.
        """
        obs = np.expand_dims(obs, 0)
        # Fill every row of a float batch with the same observation.
        batch = np.zeros((pms.batch_size, obs.shape[1]))
        batch[:] = obs[0]
        obs = batch

        if pms.use_std_network:
            means_n, logstds_n = self.session.run(
                [self.action_dist_means_n, self.action_dist_logstds_n],
                {self.obs: obs})
            mean = means_n[0]
            log_std = logstds_n[0]
        else:
            log_std = np.expand_dims([np.log(pms.std)], 0)[0]
            mean = self.session.run(self.action_dist_means_n,
                                    {self.obs: obs})[0]

        if pms.train_flag:
            rnd = np.random.normal(size=mean.shape)
            action = rnd * np.exp(log_std) + mean
        else:
            action = mean
        return action, dict(mean=mean, log_std=log_std)

    def learn(self):
        """Sample rollouts and apply TRPO updates forever.

        Every outer iteration collects ``pms.paths_number`` rollouts, runs
        ``pms.iter_num_per_train`` updates, logs and prints statistics for
        each update, and saves a checkpoint named ``iter<i>``.
        """
        start_time = time.time()
        numeptotal = 0
        i = 0
        while True:
            print("Rollout")
            self.get_samples(pms.paths_number)
            paths = self.storage.get_paths()
            # Only the first path is turned into training data; the reward
            # statistics below still cover every path.
            sample_data = self.storage.process_paths([paths[0]])

            agent_infos = sample_data["agent_infos"]
            obs_n = sample_data["observations"]
            action_n = sample_data["actions"]
            advant_n = sample_data["advantages"]
            n_samples = len(obs_n)
            # All samples are used, in their original order.
            inds = np.arange(n_samples)
            obs_n = obs_n[inds]
            action_n = action_n[inds]
            advant_n = advant_n[inds]
            action_dist_means_n = np.array(
                [agent_info["mean"] for agent_info in agent_infos[inds]])
            action_dist_logstds_n = np.array(
                [agent_info["log_std"] for agent_info in agent_infos[inds]])
            feed = {
                self.obs: obs_n,
                self.advant: advant_n,
                self.old_dist_means_n: action_dist_means_n,
                self.old_dist_logstds_n: action_dist_logstds_n,
                # The current log-std tensor is fed with the sampled values.
                self.action_dist_logstds_n: action_dist_logstds_n,
                self.action_n: action_n,
            }

            episoderewards = np.array(
                [path["rewards"].sum() for path in paths])
            average_episode_std = np.mean(np.exp(action_dist_logstds_n))

            def fisher_vector_product(p):
                # Damped product of the KL curvature with ``p``.
                feed[self.flat_tangent] = p
                return self.session.run(self.fvp, feed) + pms.cg_damping * p

            def loss(th):
                # Side effect: writes ``th`` into the network first.
                self.sff(th)
                return self.session.run(self.losses, feed_dict=feed)

            print("\n********** Iteration %i ************" % i)
            for _ in range(pms.iter_num_per_train):
                if not self.train:
                    continue
                thprev = self.gf()  # current (old) flat parameters
                g = self.session.run(self.pg, feed_dict=feed)
                # Solve for the descent direction of the surrogate loss
                # (right-hand side -g, as TRPOAgentBase.train_paths does);
                # solving for +g points uphill and makes neggdotstepdir
                # negative.
                stepdir = krylov.cg(fisher_vector_product,
                                    -g,
                                    cg_iters=pms.cg_iters)
                shs = 0.5 * stepdir.dot(fisher_vector_product(stepdir))
                fullstep = stepdir * np.sqrt(2.0 * pms.max_kl / shs)
                neggdotstepdir = -g.dot(stepdir)

                surr_prev, _, ent_prev = loss(thprev)
                mean_advant = np.mean(advant_n)
                theta = linesearch(loss, thprev, fullstep, neggdotstepdir)
                self.sff(theta)

                surrafter, kloldnew, _ = self.session.run(
                    self.losses, feed_dict=feed)

                numeptotal += len(episoderewards)

                stats = {}
                stats["average_episode_std"] = average_episode_std
                stats["sum steps of episodes"] = sample_data[
                    "sum_episode_steps"]
                stats["Total number of episodes"] = numeptotal
                stats["Average sum of rewards per episode"] = (
                    episoderewards.mean())
                stats["Time elapsed"] = "%.2f mins" % (
                    (time.time() - start_time) / 60.0)
                stats["KL between old and new distribution"] = kloldnew
                stats["Surrogate loss"] = surrafter
                stats["Surrogate loss prev"] = surr_prev
                stats["entropy"] = ent_prev
                stats["mean_advant"] = mean_advant
                log_data = [
                    average_episode_std,
                    len(episoderewards), numeptotal,
                    episoderewards.mean(), kloldnew, surrafter, surr_prev,
                    surrafter - surr_prev, ent_prev, mean_advant
                ]
                self.logger.log_row(log_data)
                for k, v in stats.items():
                    print(k + ": " + " " * (40 - len(k)) + str(v))
            self.save_model("iter" + str(i))
            i += 1

    def test(self, model_name):
        """Load ``model_name`` and run evaluation rollouts.

        Runs 100 rollouts, closes the monitor and optionally uploads the
        results when ``pms.record_movie`` is set; otherwise runs 50 rollouts.
        """
        self.load_model(model_name)
        if pms.record_movie:
            for _ in range(100):
                self.storage.get_single_path()
            self.env.env.monitor.close()
            if pms.upload_to_gym:
                # NOTE(review): the API key is hard-coded in source; consider
                # moving it to configuration.
                gym.upload("log/trpo",
                           algorithm_id='alg_8BgjkAsQRNiWu11xAhS4Hg',
                           api_key='sk_IJhy3b2QkqL3LWzgBXoVA')
        else:
            for _ in range(50):
                self.storage.get_single_path()

    def save_model(self, model_name):
        """Save the session to ``checkpoint/<model_name>.ckpt``."""
        self.saver.save(self.session, "checkpoint/" + model_name + ".ckpt")

    def load_model(self, model_name):
        """Restore the session from ``model_name`` (best effort).

        Failures are reported on stdout and otherwise ignored.
        """
        try:
            self.saver.restore(self.session, model_name)
        except Exception as err:
            # Keep loading best-effort, but do not swallow KeyboardInterrupt
            # or SystemExit, and say why the restore failed.
            print("load model %s fail: %s" % (model_name, err))