def run_n_steps_and_update_test(self):
        """Run two rollout/update rounds with a GridSolowWorker.

        Each round collects ``n_steps`` transitions with ``run_n_steps`` and
        passes them to ``update``. The ``'probs'`` returned by ``update`` must
        equal the rollout probabilities taken in reverse order. The first
        round also checks the transition count and that every loss/summary
        output is non-None.
        """
        n_steps = 10
        seq_len = 5  # max_seq_length shared by every rollout and update call

        worker = GridSolowWorker(name="test",
                                 env=SolowEnv(),
                                 policy_net=self.global_policy_net,
                                 value_net=self.global_value_net,
                                 shared_layer=self.shared_layer,
                                 global_counter=self.global_counter,
                                 discount_factor=self.discount_factor)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            worker.state = worker.env.reset()
            worker.history = [
                worker.state_processor.process_state(worker.state)]

            # Round 1: full set of checks.
            transitions, _, _, rollout_probs, _ = worker.run_n_steps(
                n_steps, sess, max_seq_length=seq_len)
            (policy_loss, value_loss, policy_summaries, value_summaries,
             update_preds) = worker.update(
                 transitions, sess, max_seq_length=seq_len)
            np.testing.assert_array_almost_equal(
                np.squeeze(update_preds['probs']),
                np.squeeze(rollout_probs[::-1][:n_steps]))
            self.assertEqual(len(transitions), n_steps)
            for output in (policy_loss, value_loss,
                           policy_summaries, value_summaries):
                self.assertIsNotNone(output)

            # Round 2: continue from the worker's current state and only
            # re-check that update echoes the reversed rollout probabilities.
            transitions, _, _, rollout_probs, _ = worker.run_n_steps(
                n_steps, sess, max_seq_length=seq_len)
            (policy_loss, value_loss, policy_summaries, value_summaries,
             update_preds) = worker.update(
                 transitions, sess, max_seq_length=seq_len)
            np.testing.assert_array_almost_equal(
                np.squeeze(update_preds['probs']),
                np.squeeze(rollout_probs[::-1]))
    def value_predict_test(self):
        """Check that ``_value_net_predict`` yields a scalar for a reset state."""
        worker = SolowWorker(
            name="test",
            env=SolowEnv(),
            policy_net=self.global_policy_net,
            value_net=self.global_value_net,
            shared_layer=self.shared_layer,
            global_counter=self.global_counter,
            discount_factor=self.discount_factor,
        )

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            raw_state = worker.env.reset()
            processed = worker.process_state(raw_state)
            temporal = worker.state_processor.process_temporal_states(
                [processed])
            # Add a leading batch dimension of size 1.
            batched_temporal = temporal.reshape((1, self.temporal_size))
            # NOTE(review): the unprocessed env state is passed here, while the
            # temporal input is built from the processed state -- confirm that
            # _value_net_predict processes its first argument internally.
            value = worker._value_net_predict(raw_state, batched_temporal, sess)
            self.assertEqual(value.shape, ())
    def grid_policy_predict_test(self):
        """Check the shape of GridSolowWorker policy probabilities.

        Originally named ``policy_predict_test``. A later method with that
        same name in this class overwrote this definition when the class body
        executed, so this test could never be collected or run. The new name
        keeps the file's ``*_test`` naming convention.

        NOTE(review): now that this test runs, confirm that the shared
        ``self.global_policy_net`` emits a ``'probs'`` key. Other tests in
        this class read ``'mu'``/``'sigma'`` from the same network.
        """
        w = GridSolowWorker(name="test",
                            env=SolowEnv(),
                            policy_net=self.global_policy_net,
                            value_net=self.global_value_net,
                            shared_layer=self.shared_layer,
                            global_counter=self.global_counter,
                            discount_factor=self.discount_factor)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            state = w.state_processor.process_state(w.env.reset())
            temporal_state = w.state_processor.process_temporal_states([state])
            # Single-example batch: flattened state plus a (1, temporal_size)
            # temporal input.
            preds = w.policy_net.predict(
                state.flatten(),
                temporal_state.reshape((1, self.temporal_size)),
                sess)
            probs = preds['probs'][0]

            self.assertEqual(probs.shape, (self.num_outputs, self.num_choices))
    def one_transition_test(self):
        """Smoke-test ``update`` on a single SolowWorker transition.

        There are no explicit assertions: the test passes as long as the
        one-step rollout and the subsequent update run without raising.
        """
        single_step = 1

        worker = SolowWorker(
            name="test",
            env=SolowEnv(),
            policy_net=self.global_policy_net,
            value_net=self.global_value_net,
            shared_layer=self.shared_layer,
            global_counter=self.global_counter,
            discount_factor=self.discount_factor,
        )

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            worker.state = worker.env.reset()
            worker.history = [worker.process_state(worker.state)]
            rollout, _, _, _, _ = worker.run_n_steps(
                single_step, sess, max_seq_length=5)
            # Feed update a list containing only the first collected transition.
            worker.update([rollout[0]], sess, max_seq_length=5)
    def policy_predict_test(self):
        """Check per-action shapes of the policy's ``mu`` and ``sigma`` outputs."""
        worker = SolowWorker(
            name="test",
            env=SolowEnv(),
            policy_net=self.global_policy_net,
            value_net=self.global_value_net,
            shared_layer=self.shared_layer,
            global_counter=self.global_counter,
            discount_factor=self.discount_factor,
        )

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            processed = worker.process_state(worker.env.reset())
            temporal = worker.state_processor.process_temporal_states(
                [processed])
            batched_temporal = temporal.reshape((1, self.temporal_size))
            preds = worker.policy_net.predict(
                processed.flatten(), batched_temporal, sess)
            mu, sigma = preds['mu'], preds['sigma']

            # The first entry of each output should hold one value per action.
            per_action = (self.num_actions, )
            self.assertEqual(mu[0].shape, per_action)
            self.assertEqual(sigma[0].shape, per_action)