Example #1
0
    def __init__(self, settings):
        """Build the Discrete DeepQ controller from a settings dict.

        Args:
            settings: dict of user overrides, merged on top of
                DEFAULT_SETTINGS via update_settings.
        """
        self.settings       = update_settings(DEFAULT_SETTINGS, settings)

        # network and training
        # Read from the merged settings so DEFAULT_SETTINGS can supply
        # "model"/"optimizer" when the caller omits them (the original read
        # the raw `settings` dict, silently ignoring defaults).
        self.q_network = parse_block(self.settings["model"])
        self.optimizer = parse_optimizer(self.settings["optimizer"])

        out_sh = self.q_network.output_shape()
        assert len(out_sh) == 2 and out_sh[0] is None, \
                "Output of the Discrete DeepQ must be (None, num_actions), where None corresponds to batch_size"
        self.num_actions      = out_sh[1]
        # Correctly spelled attribute; the misspelled `minipatch_size` alias
        # is kept so existing callers keep working.
        self.minibatch_size   = self.settings["minibatch_size"]
        self.minipatch_size   = self.minibatch_size

        # train a batch only every nth stored transition
        self.train_every_nth              = self.settings['train_every_nth']
        self.discount_rate    = self.settings["discount_rate"]

        # epsilon-greedy exploration schedule
        self.transitions_so_far        = 0
        self.exploration_period        = self.settings['exploration_period']
        self.random_action_probability = self.settings['random_action_probability']

        # experience replay buffer (FIFO, capped at replay_buffer_size)
        self.replay_buffer                = deque()
        self.store_every_nth              = self.settings['store_every_nth']
        self.replay_buffer_size           = self.settings['replay_buffer_size']

        # soft-update rate for the target network
        self.target_network_update_rate   = self.settings['target_network_update_rate']

        self.summary_writer = None

        self.s = tf.Session()

        self.create_variables()
        # NOTE(review): tf.initialize_variables is deprecated in TF >= 1.0
        # in favor of tf.variables_initializer — kept for compatibility with
        # the TF version this project targets.
        self.s.run(tf.initialize_variables(
                self.q_network.variables() + self.target_q_network.variables()))
Example #2
0
    def __init__(self, settings):
        """Build the Discrete DeepQ controller from a settings dict.

        Args:
            settings: dict of user overrides, merged on top of
                DEFAULT_SETTINGS via update_settings.
        """
        self.settings = update_settings(DEFAULT_SETTINGS, settings)

        # network and training
        # Read from the merged settings so DEFAULT_SETTINGS can supply
        # "model"/"optimizer" when the caller omits them (the original read
        # the raw `settings` dict, silently ignoring defaults).
        self.q_network = parse_block(self.settings["model"])
        self.optimizer = parse_optimizer(self.settings["optimizer"])

        out_sh = self.q_network.output_shape()
        assert len(out_sh) == 2 and out_sh[0] is None, \
                "Output of the Discrete DeepQ must be (None, num_actions), where None corresponds to batch_size"
        self.num_actions = out_sh[1]
        # Correctly spelled attribute; the misspelled `minipatch_size` alias
        # is kept so existing callers keep working.
        self.minibatch_size = self.settings["minibatch_size"]
        self.minipatch_size = self.minibatch_size

        # train a batch only every nth stored transition
        self.train_every_nth = self.settings['train_every_nth']
        self.discount_rate = self.settings["discount_rate"]

        # epsilon-greedy exploration schedule
        self.transitions_so_far = 0
        self.exploration_period = self.settings['exploration_period']
        self.random_action_probability = self.settings[
            'random_action_probability']

        # experience replay buffer (FIFO, capped at replay_buffer_size)
        self.replay_buffer = deque()
        self.store_every_nth = self.settings['store_every_nth']
        self.replay_buffer_size = self.settings['replay_buffer_size']

        # soft-update rate for the target network
        self.target_network_update_rate = self.settings[
            'target_network_update_rate']

        self.summary_writer = None

        self.s = tf.Session()

        self.create_variables()
        # NOTE(review): tf.initialize_variables is deprecated in TF >= 1.0
        # in favor of tf.variables_initializer — kept for compatibility with
        # the TF version this project targets.
        self.s.run(
            tf.initialize_variables(self.q_network.variables() +
                                    self.target_q_network.variables()))
Example #3
0
 def test_basic(self):
     """A None argument is ignored: the other value comes back unchanged."""
     for first, second in (('siema', None), (None, 'siema')):
         self.assertEqual(update_settings(first, second), 'siema')
Example #4
0
 def test_updates(self):
     """Keys from both dicts survive; the second dict wins on collisions."""
     merged = update_settings({'a': 1}, {'b': 2})
     self.assertEqual(merged, {'a': 1, 'b': 2})
     overridden = update_settings({'a': 1}, {'a': 2})
     self.assertEqual(overridden, {'a': 2})
Example #5
0
 def test_immutable(self):
     """Merging must not mutate either input dictionary."""
     base, extra = {'a': 1}, {'b': 2}
     merged = update_settings(base, extra)
     self.assertEqual(merged, {'a': 1, 'b': 2})
     # Both originals are untouched by the merge.
     self.assertEqual(base, {'a': 1})
     self.assertEqual(extra, {'b': 2})