Code example #1
File: rainbow.py  Project: bekerov/SRLF
    def train(self):
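        # Start a local Redis server on port 12000; it is used to exchange weights and rollout paths with worker processes.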
        cmd_server = 'redis-server --port 12000'
        p = subprocess.Popen(cmd_server, shell=True, preexec_fn=os.setsid)
        self.variables_server = Redis(port=12000)
        means = "-"
        stds = "-"
        if self.scale != 'off':
            if self.timestep == 0:
                print("Time to measure features!")
                if self.distributed:
                    worker_args = \
                        {
                            'config': self.config,
                            'test_mode': False,
                        }
                    hlp.launch_workers(worker_args, self.n_workers)
                    paths = []
                    for i in range(self.n_workers):
                        paths += hlp.load_object(self.variables_server.get("paths_{}".format(i)))
                else:
                    self.test_mode = False
                    self.make_rollout()
                    paths = self.paths

                for path in paths:
                    self.sums += path["sumobs"]
                    self.sumsqrs += path["sumsqrobs"]
                    self.sumtime += len(path["rewards"])

            stds = np.sqrt((self.sumsqrs - np.square(self.sums) / self.sumtime) / (self.sumtime - 1))
            means = self.sums / self.sumtime
            print("Init means: {}".format(means))
            print("Init stds: {}".format(stds))
            self.variables_server.set("means", hlp.dump_object(means))
            self.variables_server.set("stds", hlp.dump_object(stds))
            self.sess.run(self.norm_set_op, feed_dict=dict(zip(self.norm_phs, [means, stds])))
        print("Let's go!")
        self.update_target_weights(alpha=1.0)
        index_replay = 0
        iteration = 0
        episode = 0
        idxs_range = np.arange(self.xp_size)
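        # Pre-allocate flat NumPy arrays that act as the experience replay buffer of size xp_size.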
        xp_replay_state = np.zeros(shape=(self.xp_size, self.env.get_observation_space()))
        xp_replay_next_state = np.zeros(shape=(self.xp_size, self.env.get_observation_space()))
        xp_replay_reward = np.zeros(shape=(self.xp_size,))
        xp_replay_action = np.zeros(shape=(self.xp_size,))
        xp_replay_terminal = np.zeros(shape=(self.xp_size,))
        if self.prioritized:
            xp_replay_priority = np.zeros(shape=(self.xp_size,))
            self.max_prior = 1
        start_time = time.time()
        self.last_state = self.env.reset()
        discounts = self.gamma ** np.arange(self.n_steps)
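        # Circular buffers holding the most recent n_steps states, actions and rewards, used to form n-step returns.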
        self.last_rewards = np.zeros(shape=(self.n_steps,))
        self.last_states = np.zeros(shape=(self.n_steps, self.n_features))
        self.last_actions = np.zeros(shape=(self.n_steps, ))
        buffer_index = 0
        env = self.env
        while True:
            if iteration <= self.random_steps:
                actions = env.env.action_space.sample()
            else:
                actions = self.act(env.features, exploration=True)
            self.last_states[buffer_index] = env.features.reshape(-1)
            self.last_actions[buffer_index] = actions
            env.step([actions])
            self.last_rewards[buffer_index] = env.reward
            buffer_index = (buffer_index + 1) % self.n_steps

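            # Once at least n_steps transitions are buffered, store an n-step transition in the replay buffer.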
            if env.timestamp >= self.n_steps:
                xp_replay_state[index_replay] = np.copy(self.last_states[buffer_index])
                xp_replay_next_state[index_replay] = env.features.reshape(-1)
                discounted_return = np.sum(discounts*self.last_rewards[np.roll(np.arange(self.n_steps), -(buffer_index))])
                xp_replay_reward[index_replay] = discounted_return
                xp_replay_action[index_replay] = self.last_actions[buffer_index]
                xp_replay_terminal[index_replay] = env.done
                if self.prioritized:
                    xp_replay_priority[index_replay] = self.max_prior
                index_replay = (index_replay + 1) % self.xp_size

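            # At episode end, flush the remaining shorter-horizon returns into the replay buffer and reset the circular buffers.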
            if env.done or env.timestamp > self.timesteps_per_launch:
                episode += 1
                print("Episode #{}".format(episode), env.get_total_reward())
                self.train_scores.append(env.get_total_reward())
                for i in range(1, self.n_steps):
                    buffer_index = (buffer_index + 1) % self.n_steps

                    xp_replay_state[index_replay] = np.copy(self.last_states[buffer_index])
                    xp_replay_next_state[index_replay] = env.features.reshape(-1)
                    discounted_return = np.sum(
                        discounts[:self.n_steps-i] * self.last_rewards[np.roll(np.arange(self.n_steps), -(buffer_index))[:self.n_steps-i]])
                    xp_replay_reward[index_replay] = discounted_return
                    xp_replay_action[index_replay] = self.last_actions[buffer_index]
                    xp_replay_terminal[index_replay] = env.done
                    index_replay = (index_replay + 1) % self.xp_size
                env.reset()
                self.last_rewards = np.zeros(shape=(self.n_steps,))
                self.last_states = np.zeros(shape=(self.n_steps, self.n_features))
                self.last_actions = np.zeros(shape=(self.n_steps,))
                buffer_index = 0

            self.last_state = env.features
            if iteration % 1000 == 0:
                print("Iteration #{}".format(iteration))
                self.save(self.config[:-5])

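            # After the random warm-up phase, sample a minibatch (prioritized or uniform) and run one training step.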
            if iteration > self.random_steps:
                if self.prioritized:
                    max_id = np.min([xp_replay_state.shape[0], iteration])
                    probs = xp_replay_priority[:max_id]/np.sum(xp_replay_priority[:max_id])
                    idxs = np.random.choice(idxs_range[:max_id], size=self.batch_size, p=probs)
                    importance_weights = (1/(max_id*probs[idxs]))**self.prior_beta
                else:
                    idxs = np.random.randint(np.min([xp_replay_state.shape[0], iteration]), size=self.batch_size)
                    importance_weights = np.ones(shape=(self.batch_size,))

                state_batch = xp_replay_state[idxs]
                next_state_batch = xp_replay_next_state[idxs]
                action_batch = xp_replay_action[idxs]
                reward_batch = xp_replay_reward[idxs]
                done_batch = xp_replay_terminal[idxs]

                feed_dict = {
                    self.state_input: state_batch,
                    self.next_state_input: next_state_batch,
                    self.action_input: action_batch,
                    self.importance_weights: importance_weights
                }
                target_atom_probs = self.sess.run(self.target_atom_probs, feed_dict)
                target_atom_probs = np.exp(target_atom_probs)

                if not self.double:
                    target_q_values = target_atom_probs * np.tile(np.arange(self.n_atoms).reshape((1, 1, self.n_atoms)), [self.batch_size, 1, 1])
                    target_q_values = np.sum(target_q_values, axis=2)
                    target_greedy_actions = np.argmax(target_q_values, axis=1).astype(np.int32).reshape((-1, 1))
                    target_probs = target_atom_probs[np.arange(self.batch_size).reshape((-1, 1)), target_greedy_actions]
                else:
                    feed_dict[self.state_input] = next_state_batch
                    atom_probs = self.sess.run(self.atom_probs, feed_dict)
                    atom_probs = np.exp(atom_probs)
                    q_values = atom_probs * np.tile(np.arange(self.n_atoms).reshape((1, 1, self.n_atoms)),
                                                                  [self.batch_size, 1, 1])
                    q_values = np.sum(q_values, axis=2)
                    greedy_actions = np.argmax(q_values, axis=1).astype(np.int32).reshape((-1, 1))
                    target_probs = target_atom_probs[np.arange(self.batch_size).reshape((-1, 1)), greedy_actions]
                    feed_dict[self.state_input] = state_batch

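                # Categorical (C51-style) target projection: shift and scale the atom support by the n-step reward and discount, then redistribute probability mass onto the fixed support.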
                atom_values = np.arange(self.n_atoms, dtype=np.float32).reshape((-1, self.n_atoms))
                atom_values = 2 * self.max_q_magnitude * (
                    np.tile(atom_values, [self.batch_size, 1]) / (self.n_atoms - 1) - 0.5)
                atom_new_values = np.clip(
                    (self.gamma ** self.n_steps) * atom_values * (1 - done_batch).reshape(-1, 1)
                    + reward_batch.reshape((-1, 1)),
                    -self.max_q_magnitude, self.max_q_magnitude)
                new_positions = ((atom_new_values / (2 * self.max_q_magnitude) + 0.5) * (self.n_atoms - 1)).reshape((-1))
                lower = np.floor(new_positions).astype(np.int32).reshape(-1)
                upper = np.floor(new_positions).astype(np.int32).reshape(-1) + 1

                batch_rows = np.sort(np.tile(np.arange(self.batch_size), [self.n_atoms]))
                atom_cols = np.tile(np.arange(self.n_atoms), [self.batch_size])
                final_target_probs = np.zeros(shape=(self.batch_size, self.n_atoms + 1, self.n_atoms))
                final_target_probs[batch_rows, lower, atom_cols] += (upper - new_positions) * target_probs.reshape((-1))
                final_target_probs[batch_rows, upper, atom_cols] += (new_positions - lower) * target_probs.reshape((-1))

                final_target_probs = np.sum(final_target_probs, axis=2)[:, :-1]
                feed_dict[self.target_probs] = final_target_probs
                KLs = self.sess.run([self.loss, self.train_op], feed_dict)[0]
                if self.prioritized:
                    xp_replay_priority[idxs] = KLs ** self.prior_alpha
                self.update_target_weights()

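                # Periodically push the current weights to Redis and evaluate the policy with test rollouts.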
                if iteration % self.test_every == 0:
                    weights = self.get_weights()
                    for i, weight in enumerate(weights):
                        self.variables_server.set("weight_" + str(i), hlp.dump_object(weight))
                    print("Time to test!")
                    if self.distributed:
                        weights = self.get_weights()
                        for i, weight in enumerate(weights):
                            self.variables_server.set("weight_" + str(i), hlp.dump_object(weight))
                        worker_args = \
                            {
                                'config': self.config,
                                'test_mode': True,
                            }
                        hlp.launch_workers(worker_args, self.n_workers)
                        paths = []
                        for i in range(self.n_workers):
                            paths += hlp.load_object(self.variables_server.get("paths_{}".format(i)))
                    else:
                        self.test_mode = True
                        self.make_rollout()
                        paths = self.paths

                    total_rewards = np.array([path["total"] for path in paths])
                    eplens = np.array([len(path["rewards"]) for path in paths])
                    print("""
-------------------------------------------------------------
Mean test score:           {test_scores}
Mean test episode length:  {test_eplengths}
Max test score:            {max_test}
Time for iteration:        {tt}
Mean of features:          {means}
Std of features:           {stds}
-------------------------------------------------------------
                                    """.format(
                        means=means,
                        stds=stds,
                        test_scores=np.mean(total_rewards),
                        test_eplengths=np.mean(eplens),
                        max_test=np.max(total_rewards),
                        tt=time.time() - start_time
                    ))
                    start_time = time.time()
                    self.test_scores.append(np.mean(total_rewards))

            iteration += 1
            self.timestep += 1
Code example #2
File: a3c_discrete.py  Project: bekerov/SRLF
    def train(self):
        cmd_server = 'redis-server --port 12000'
        p = subprocess.Popen(cmd_server, shell=True, preexec_fn=os.setsid)
        self.variables_server = Redis(port=12000)
        means = "-"
        stds = "-"
        if self.scale != 'off':
            if self.timestep == 0:
                print("Time to measure features!")
                if self.distributed:
                    worker_args = \
                        {
                            'config': self.config,
                            'test_mode': False,
                        }
                    hlp.launch_workers(worker_args, self.n_workers)
                    paths = []
                    for i in range(self.n_workers):
                        paths += hlp.load_object(
                            self.variables_server.get("paths_{}".format(i)))
                else:
                    self.test_mode = False
                    self.make_rollout()
                    paths = self.paths

                for path in paths:
                    self.sums += path["sumobs"]
                    self.sumsqrs += path["sumsqrobs"]
                    self.sumtime += path["observations"].shape[0]

            stds = np.sqrt(
                (self.sumsqrs - np.square(self.sums) / self.sumtime) /
                (self.sumtime - 1))
            means = self.sums / self.sumtime
            print("Init means: {}".format(means))
            print("Init stds: {}".format(stds))
            self.variables_server.set("means", hlp.dump_object(means))
            self.variables_server.set("stds", hlp.dump_object(stds))
            self.sess.run(self.norm_set_op,
                          feed_dict=dict(zip(self.norm_phs, [means, stds])))

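        # Publish the initial weights together with zeroed momentum and velocity buffers for the shared optimizer state.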
        weights = self.get_weights()
        for i, weight in enumerate(weights):
            self.variables_server.set("weight_" + str(i),
                                      hlp.dump_object(weight))
            self.variables_server.set('momentum_{}'.format(i),
                                      hlp.dump_object(np.zeros(weight.shape)))
            self.variables_server.set('velocity_{}'.format(i),
                                      hlp.dump_object(np.zeros(weight.shape)))
        self.variables_server.set('update_steps', hlp.dump_object(0))

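        # Launch the training workers in the background (command='work'); they update the shared weights asynchronously.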
        worker_args = \
            {
                'config': self.config,
                'test_mode': False,
            }
        hlp.launch_workers(worker_args,
                           self.n_workers,
                           command='work',
                           wait=False)

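        # The master process only monitors training: every test_every seconds it runs test rollouts and logs statistics.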
        while True:
            time.sleep(self.test_every)
            print("Time for testing!")
            if self.distributed:
                worker_args = \
                    {
                        'config': self.config,
                        'test_mode': True,
                    }
                hlp.launch_workers(worker_args, self.n_workers)
                paths = []
                for i in range(self.n_workers):
                    paths += hlp.load_object(
                        self.variables_server.get("paths_{}".format(i)))
            else:
                self.test_mode = True
                self.make_rollout()
                paths = self.paths

            total_rewards = np.array([path["total"] for path in paths])
            eplens = np.array([len(path["rewards"]) for path in paths])

            print("""
-------------------------------------------------------------
Mean test score:           {test_scores}
Mean test episode length:  {test_eplengths}
Max test score:            {max_test}
Number of train episodes:  {number}
Mean of features:          {means}
Std of features:           {stds}
-------------------------------------------------------------
                """.format(means=means,
                           stds=stds,
                           test_scores=np.mean(total_rewards),
                           test_eplengths=np.mean(eplens),
                           max_test=np.max(total_rewards),
                           number=self.variables_server.llen('results')))
            self.timestep += 1
            self.train_scores = [
                hlp.load_object(res)
                for res in self.variables_server.lrange('results', 0, -1)
            ][::-1]

            self.test_scores.append(np.mean(total_rewards))
            if self.timestep % self.save_every == 0:
                self.save(self.config[:-5])
Code example #3
    def train(self):
        cmd_server = 'redis-server --port 12000'
        p = subprocess.Popen(cmd_server, shell=True, preexec_fn=os.setsid)
        self.variables_server = Redis(port=12000)
        means = "-"
        stds = "-"
        if self.scale != 'off':
            if self.timestep == 0:
                print("Time to measure features!")
                if self.distributed:
                    worker_args = \
                        {
                            'config': self.config,
                            'test_mode': False,
                        }
                    hlp.launch_workers(worker_args, self.n_workers)
                    paths = []
                    for i in range(self.n_workers):
                        paths += hlp.load_object(
                            self.variables_server.get("paths_{}".format(i)))
                else:
                    self.test_mode = False
                    self.make_rollout()
                    paths = self.paths

                for path in paths:
                    self.sums += path["sumobs"]
                    self.sumsqrs += path["sumsqrobs"]
                    self.sumtime += len(path["rewards"])
                stds = np.sqrt(
                    (self.sumsqrs - np.square(self.sums) / self.sumtime) /
                    (self.sumtime - 1))
                means = self.sums / self.sumtime
                print("Init means: {}".format(means))
                print("Init stds: {}".format(stds))
                self.variables_server.set("means", hlp.dump_object(means))
                self.variables_server.set("stds", hlp.dump_object(stds))
                self.sess.run(self.norm_set_op,
                              feed_dict=dict(zip(self.norm_phs,
                                                 [means, stds])))
        while True:
            print("Iteration {}".format(self.timestep))
            start_time = time.time()
            weight_noises = []
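            # Re-seed the RNG and draw one seed per task; workers regenerate the identical parameter noise from these seeds.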
            random.seed()
            seed_for_random = random.randint(0, np.iinfo(np.int32).max)
            np.random.seed(seed_for_random)
            seeds = np.random.randint(-np.iinfo(np.int32).min +
                                      np.iinfo(np.int32).max,
                                      size=self.n_tasks_all)
            self.variables_server.set("seeds", hlp.dump_object(seeds))

            weights = self.get_weights()
            for i, weight in enumerate(weights):
                self.variables_server.set("weight_" + str(i),
                                          hlp.dump_object(weight))
                weight_noises.append(
                    np.empty((self.n_tasks_all, ) + weight.shape))

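            # Reconstruct each task's Gaussian weight noise locally from its seed.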
            for index in range(self.n_tasks_all):
                np.random.seed(seeds[index])
                for i, weight in enumerate(weights):
                    weight_noises[i][index] = np.random.normal(
                        size=weight.shape)

            if self.distributed:
                weights = self.get_weights()
                for i, weight in enumerate(weights):
                    self.variables_server.set("weight_" + str(i),
                                              hlp.dump_object(weight))
                worker_args = \
                    {
                        'config': self.config,
                        'test_mode': False,
                    }
                hlp.launch_workers(worker_args,
                                   self.n_workers,
                                   command='rollout_with_noise')
                paths = []
                for i in range(self.n_workers):
                    paths += hlp.load_object(
                        self.variables_server.get("paths_{}".format(i)))
            else:
                self.test_mode = False
                self.make_rollout()
                paths = self.paths

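            # Collect scores and episode lengths for both perturbations of each task (positive and negative noise indices).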
            scores = []
            train_lengths = []
            for i in range(self.n_tasks_all):
                scores.append(
                    hlp.load_object(
                        self.variables_server.get("scores_" + str(i))))
                train_lengths.append(
                    hlp.load_object(
                        self.variables_server.get("eplen_" + str(i))))
                scores.append(
                    hlp.load_object(
                        self.variables_server.get("scores_" + str(-i))))
                train_lengths.append(
                    hlp.load_object(
                        self.variables_server.get("eplen_" + str(-i))))

            scores = np.array(scores)
            train_mean_score = np.mean(scores)
            ranks = np.zeros(shape=scores.shape)

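            # Normalize the scores, either by ranks or by centering and standardizing, before forming the gradient estimate.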
            if self.normalize == 'ranks':
                ranks[np.argsort(scores)] = np.arange(
                    ranks.shape[0], dtype=np.float32) / (ranks.shape[0] - 1)
                ranks -= 0.5
            elif self.normalize == 'center':
                ranks = scores.copy()  # copy (not a view) so the in-place normalization below leaves scores untouched
                ranks -= train_mean_score
                ranks /= (np.std(ranks, ddof=1) + 0.001)

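            # Evolution-strategies-style gradient estimate: weight each noise vector by the score difference of its perturbation pair, then subtract a weight-decay term scaled by l1_reg.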
            gradients = [np.zeros(w.get_shape()) for w in self.weights]
            for i, weight in enumerate(weights):
                for index in 2 * np.arange(seeds.shape[0]):
                    gradients[i] += weight_noises[i][index // 2] * (
                        ranks[index] - ranks[index + 1]) / self.n_tasks_all
                gradients[i] -= self.l1_reg * weights[i]

            if self.adam:
                self.apply_adam_updates(gradients)
            else:
                for i, weight in enumerate(weights):
                    weights[i] += self.learning_rate * gradients[i]
                self.sess.run(self.set_op,
                              feed_dict=dict(zip(self.weights_phs, weights)))

            print("Time to testing!")

            if self.distributed:
                weights = self.get_weights()
                for i, weight in enumerate(weights):
                    self.variables_server.set("weight_" + str(i),
                                              hlp.dump_object(weight))
                worker_args = \
                    {
                        'config': self.config,
                        'test_mode': True,
                    }
                hlp.launch_workers(worker_args, self.n_workers)
                paths = []
                for i in range(self.n_workers):
                    paths += hlp.load_object(
                        self.variables_server.get("paths_{}".format(i)))
            else:
                self.test_mode = True
                self.make_rollout()
                paths = self.paths

            total_rewards = np.array([path["total"] for path in paths])
            eplens = np.array([len(path["rewards"]) for path in paths])
            if self.scale:
                for i in range(self.n_tasks_all):
                    self.sums += hlp.load_object(
                        self.variables_server.get("sum_{}".format(i)))
                    self.sumsqrs += hlp.load_object(
                        self.variables_server.get("sumsqr_{}".format(i)))
                self.sumtime += np.sum(train_lengths)
                stds = np.sqrt(
                    (self.sumsqrs - np.square(self.sums) / self.sumtime) /
                    (self.sumtime - 1))
                means = self.sums / self.sumtime
                self.variables_server.set("means", hlp.dump_object(means))
                self.variables_server.set("stds", hlp.dump_object(stds))
                self.sess.run(self.norm_set_op,
                              feed_dict=dict(zip(self.norm_phs,
                                                 [means, stds])))

            print("""
-------------------------------------------------------------
Mean test score:           {test_scores}
Mean train score:          {train_scores}
Mean test episode length:  {test_eplengths}
Mean train episode length: {train_eplengths}
Max test score:            {max_test}
Max train score:           {max_train}
Mean of features:          {means}
Std of features:           {stds}
Time for iteration:        {tt}
-------------------------------------------------------------
                """.format(means=means,
                           stds=stds,
                           test_scores=np.mean(total_rewards),
                           test_eplengths=np.mean(eplens),
                           train_scores=train_mean_score,
                           train_eplengths=np.mean(train_lengths),
                           max_test=np.max(total_rewards),
                           max_train=np.max(scores),
                           tt=time.time() - start_time))
            self.train_scores.append(train_mean_score)
            self.test_scores.append(np.mean(total_rewards))
            if self.timestep % self.save_every == 0:
                self.save(self.config[:-5])
Code example #4
File: trpo_discrete.py  Project: bekerov/SRLF
    def train(self):
        cmd_server = 'redis-server --port 12000'
        p = subprocess.Popen(cmd_server, shell=True, preexec_fn=os.setsid)
        self.variables_server = Redis(port=12000)
        means = "-"
        stds = "-"
        if self.scale != 'off':
            if self.timestep == 0:
                print("Time to measure features!")
                if self.distributed:
                    worker_args = \
                        {
                            'config': self.config,
                            'test_mode': False,
                        }
                    hlp.launch_workers(worker_args, self.n_workers)
                    paths = []
                    for i in range(self.n_workers):
                        paths += hlp.load_object(
                            self.variables_server.get("paths_{}".format(i)))
                else:
                    self.test_mode = False
                    self.make_rollout()
                    paths = self.paths

                for path in paths:
                    self.sums += path["sumobs"]
                    self.sumsqrs += path["sumsqrobs"]
                    self.sumtime += path["observations"].shape[0]

            stds = np.sqrt(
                (self.sumsqrs - np.square(self.sums) / self.sumtime) /
                (self.sumtime - 1))
            means = self.sums / self.sumtime
            print("Init means: {}".format(means))
            print("Init stds: {}".format(stds))
            self.variables_server.set("means", hlp.dump_object(means))
            self.variables_server.set("stds", hlp.dump_object(stds))
            self.sess.run(self.norm_set_op,
                          feed_dict=dict(zip(self.norm_phs, [means, stds])))
        while True:
            print("Iteration {}".format(self.timestep))
            start_time = time.time()

            if self.distributed:
                weights = self.get_weights()
                for i, weight in enumerate(weights):
                    self.variables_server.set("weight_" + str(i),
                                              hlp.dump_object(weight))
                worker_args = \
                    {
                        'config': self.config,
                        'test_mode': False,
                    }
                hlp.launch_workers(worker_args, self.n_workers)
                paths = []
                for i in range(self.n_workers):
                    paths += hlp.load_object(
                        self.variables_server.get("paths_{}".format(i)))
            else:
                self.test_mode = False
                self.make_rollout()
                paths = self.paths

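            # Assemble the training batch: observations, action tuples, discounted returns and advantages from discounted TD residuals.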
            observations = np.concatenate(
                [path["observations"] for path in paths])
            actions = np.concatenate([path["action_tuples"] for path in paths])
            action_dists = []
            for _ in range(len(self.n_actions)):
                action_dists.append([])
            returns = []
            advantages = []
            for path in paths:
                self.sums += path["sumobs"]
                self.sumsqrs += path["sumsqrobs"]
                self.sumtime += path["rewards"].shape[0]
                dists = path["dist_tuples"]

                for i in range(len(self.n_actions)):
                    action_dists[i] += [dist[i][0] for dist in dists]
                returns += hlp.discount(path["rewards"], self.gamma,
                                        path["timestamps"]).tolist()
                values = self.sess.run(
                    self.value,
                    feed_dict={self.state_input: path["observations"]})
                values = np.append(values,
                                   0 if path["terminated"] else values[-1])
                deltas = (path["rewards"] + self.gamma * values[1:] -
                          values[:-1])
                advantages += hlp.discount(deltas, self.gamma,
                                           path["timestamps"]).tolist()
            returns = np.array(returns)
            advantages = np.array(advantages)

            if self.normalize == 'ranks':
                ranks = np.zeros_like(advantages)
                ranks[np.argsort(advantages)] = np.arange(
                    ranks.shape[0], dtype=np.float32) / (ranks.shape[0] - 1)
                ranks -= 0.5
                advantages = ranks[:]
            elif self.normalize == 'center':
                advantages -= np.mean(advantages)
                advantages /= (np.std(advantages, ddof=1) + 0.001)

            feed_dict = {
                self.state_input: observations,
                self.targets["return"]: returns,
                self.targets["advantage"]: advantages
            }

            for i in range(len(self.n_actions)):
                feed_dict[self.targets["old_dist_{}".format(i)]] = np.array(
                    action_dists[i])
                feed_dict[self.targets["action_{}".format(i)]] = actions[:, i]

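            # Fit the value function with several gradient steps on the current batch.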
            for i in range(self.value_updates):
                self.sess.run(self.value_train_op, feed_dict)

            train_rewards = np.array([path["rewards"].sum() for path in paths])
            train_lengths = np.array([len(path["rewards"]) for path in paths])

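            # TRPO policy update: compute the natural-gradient direction with conjugate gradient on Fisher-vector products, then take a KL-constrained line-search step.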
            thprev = self.get_flat()

            def fisher_vector_product(p):
                feed_dict[self.targets["flat_tangent"]] = p
                return self.sess.run(self.fisher_vector_product,
                                     feed_dict) + 0.1 * p

            g = self.sess.run(self.policy_grad, feed_dict)
            stepdir = hlp.conjugate_gradient(fisher_vector_product, -g)

            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
            lm = np.sqrt(shs / self.max_kl)
            fullstep = stepdir / (lm + 1e-18)

            def loss_kl(th):
                self.set_from_flat(th)
                return self.sess.run([self.loss, self.KL], feed_dict=feed_dict)

            theta = hlp.linesearch(loss_kl, thprev, fullstep, self.max_kl)
            self.set_from_flat(theta)

            lossafter, kloldnew = self.sess.run([self.loss, self.KL],
                                                feed_dict=feed_dict)

            print("Time for testing!")

            if self.distributed:
                weights = self.get_weights()
                for i, weight in enumerate(weights):
                    self.variables_server.set("weight_" + str(i),
                                              hlp.dump_object(weight))
                worker_args = \
                    {
                        'config': self.config,
                        'test_mode': True,
                    }
                hlp.launch_workers(worker_args, self.n_workers)
                paths = []
                for i in range(self.n_workers):
                    paths += hlp.load_object(
                        self.variables_server.get("paths_{}".format(i)))
            else:
                self.test_mode = True
                self.make_rollout()
                paths = self.paths

            total_rewards = np.array([path["total"] for path in paths])
            eplens = np.array([len(path["rewards"]) for path in paths])

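            # Refresh the running feature-normalization statistics (skipped when scale == 'full').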
            if self.scale != 'full':
                stds = np.sqrt(
                    (self.sumsqrs - np.square(self.sums) / self.sumtime) /
                    (self.sumtime - 1))
                means = self.sums / self.sumtime
                self.variables_server.set("means", hlp.dump_object(means))
                self.variables_server.set("stds", hlp.dump_object(stds))
                self.sess.run(self.norm_set_op,
                              feed_dict=dict(zip(self.norm_phs,
                                                 [means, stds])))

            print("""
-------------------------------------------------------------
Mean test score:           {test_scores}
Mean train score:          {train_scores}
Mean test episode length:  {test_eplengths}
Mean train episode length: {train_eplengths}
Max test score:            {max_test}
Max train score:           {max_train}
KL between old and new     {kl}
Loss after update          {loss}
Mean of features:          {means}
Std of features:           {stds}
-------------------------------------------------------------
                """.format(means=means,
                           stds=stds,
                           test_scores=np.mean(total_rewards),
                           test_eplengths=np.mean(eplens),
                           train_scores=np.mean(train_rewards),
                           train_eplengths=np.mean(train_lengths),
                           max_test=np.max(total_rewards),
                           max_train=np.max(train_rewards),
                           kl=kloldnew,
                           loss=lossafter))
            self.timestep += 1
            self.train_scores.append(np.mean(train_rewards))
            self.test_scores.append(np.mean(total_rewards))
            if self.timestep % self.save_every == 0:
                self.save(self.config[:-5])