Example 1 (Theano/Lasagne)

import numpy as np
import theano
import theano.tensor as T
import lasagne

class DeepQNetwork(object):
    def __init__(self, height, width, num_actions, state_frames, gamma, batch_size=32):
        self.gamma = gamma
        self.num_actions = num_actions
        self.input_var = T.tensor4('inputs')
        self.target_var = T.vector('targets')
        self.actions_var = T.matrix()

        # Shared variables holding one mini-batch; the dtype is set on the numpy
        # arrays, since theano.shared() does not take a dtype keyword argument.
        self.input_shared = theano.shared(
            np.zeros([batch_size, state_frames, height, width], dtype=theano.config.floatX))
        self.target_shared = theano.shared(
            np.zeros([batch_size, 1], dtype=theano.config.floatX))
        self.actions_shared = theano.shared(
            np.zeros([batch_size, num_actions], dtype='int32'))

        self.net = get_inference(self.input_var, num_actions, height, width, state_frames)
        self.prediction = (lasagne.layers.get_output(self.net) * self.actions_var).sum(axis=1)
        self.error = T.sqr(self.target_var - self.prediction).mean()

        params = lasagne.layers.get_all_params(self.net, trainable=True)
        #self.opt = lasagne.updates.rmsprop(
        #    self.error, params, learning_rate=0.0025, epsilon=0.01)
        #self.opt = lasagne.updates.adagrad(
        #    self.error, params, learning_rate=0.001
        #)
        self.opt = gen_updates_rmsprop(params, T.grad(self.error, params))

        self.train_step = theano.function([self.input_var, self.target_var, self.actions_var],
                                          self.error, updates=self.opt, allow_input_downcast=True)
        self.net_fun = theano.function([self.input_var], lasagne.layers.get_output(self.net),
                                       allow_input_downcast=True)

        self.monitor = Monitoring()

        self.pred_fun = theano.function([self.input_var, self.actions_var], self.prediction, allow_input_downcast=True)
        self.error_fun = theano.function([self.input_var, self.target_var, self.actions_var], self.error,
                                         allow_input_downcast=True)


    def get_best_action(self, example):
        vals = self.evaluate_network([example])[0]
        return np.argmax(vals), vals

    def evaluate_network(self, batch):
        return self.net_fun(batch)

    def train(self, batch):
        batch_size = batch['batch_size']
        self.monitor.report_action_start('get_target')
        target = self.get_target(batch)
        self.monitor.report_action_finish('get_target')
        # One-hot mask selecting, for each example, the action that was taken.
        actions = np.zeros([batch_size, self.num_actions])
        actions[np.arange(batch_size), list(batch['actions'])] = 1

        # The timed pred_fun/error_fun calls below are diagnostics only; their results are discarded.
        self.monitor.report_action_start('pred_fun')
        self.pred_fun(batch['states_1'], actions)
        self.monitor.report_action_finish('pred_fun')

        self.monitor.report_action_start('error_fun')
        self.error_fun(batch['states_1'], target, actions)
        self.monitor.report_action_finish('error_fun')

        self.monitor.report_action_start('train_step')
        error_val = self.train_step(batch['states_1'], target, actions)
        self.monitor.report_action_finish('train_step')

        return error_val

    def get_target(self, batch):
        batch_size = batch['batch_size']
        self.monitor.report_action_start('only_q')
        q_vals = self.evaluate_network(batch['states_2'])
        self.monitor.report_action_finish('only_q')
        best_actions = np.argmax(q_vals, 1)
        # max_a Q(s_{t+1}, a) for every example in the batch
        target = np.choose(best_actions, q_vals.T)
        for i in xrange(batch_size):
            # Q-learning target: r_t + gamma * max_a Q(s_{t+1}, a),
            # with the bootstrap term dropped on terminal transitions.
            if batch['terminations'][i]:
                mul = 0
            else:
                mul = self.gamma
            target[i] = batch['rewards'][i] + mul * target[i]
        return target
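A minimal usage sketch for the class above, assuming the project-specific helpers get_inference, gen_updates_rmsprop and Monitoring are available; the batch keys and shapes are inferred from train() and get_target(), and the concrete dimensions and hyper-parameters are placeholders:

import numpy as np

height, width, state_frames, num_actions, batch_size = 84, 84, 4, 6, 32
net = DeepQNetwork(height, width, num_actions, state_frames, gamma=0.99, batch_size=batch_size)

batch = {
    'batch_size': batch_size,
    'states_1': np.random.rand(batch_size, state_frames, height, width),  # states s_t
    'states_2': np.random.rand(batch_size, state_frames, height, width),  # successor states s_{t+1}
    'actions': np.random.randint(num_actions, size=batch_size),
    'rewards': np.random.rand(batch_size),
    'terminations': np.zeros(batch_size, dtype=bool),
}
error = net.train(batch)                                       # one gradient step on the squared TD error
action, q_values = net.get_best_action(batch['states_1'][0])   # greedy action for a single state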
Example 2 (TensorFlow)

import numpy as np
import tensorflow as tf

class DeepQNetwork(object):
    def __init__(self, height, width, session, num_actions, state_frames, gamma, net_type=1,
                 optimizer=tf.train.AdamOptimizer(1e-6)):
        self.height = height
        self.width = width
        self.session = session
        self.num_actions = num_actions
        self.state_frames = state_frames
        self.gamma = gamma
        self.opt = optimizer
        self.train_counter = 0

        self.input_pl = tf.placeholder(tf.float32, [None, height, width, state_frames])
        self.target_pl = tf.placeholder(tf.float32, [None])
        self.actions_pl = tf.placeholder(tf.float32, [None, num_actions])

        if net_type == 1:
            self.network = get_inference(self.input_pl, num_actions, 'main_net')
        else:
            self.network = get_inference2(self.input_pl, num_actions, 'main_net')

        self.prediction = tf.reduce_sum(tf.mul(self.network.output, self.actions_pl), 1)
        #self.error = tf.reduce_mean(tf.square(self.target_pl - self.prediction))
        # Per-example squared TD error; minimize() implicitly sums it, and train() reports the mean.
        self.error = tf.square(self.target_pl - self.prediction)

        self.train_step = self.opt.minimize(self.error)

        self.session.run([tf.initialize_all_variables()])

        #todo: throw that away
        self.monitor = Monitoring()

    def get_best_action(self, example):
        vals = self.evaluate_network([example.transpose([1, 2, 0])])[0]
        return np.argmax(vals), vals

    def train(self, batch):
        self.train_counter += 1
        prepare_batch(batch)
        batch_size = batch['batch_size']
        self.monitor.report_action_start('get_batch')
        target = self.get_target(batch)
        self.monitor.report_action_finish('get_batch')
        actions = np.zeros([batch_size, self.num_actions])
        actions[np.arange(batch_size), list(batch['actions'])] = 1
        self.monitor.report_action_start('pred')
        _, error_val, pred = self.session.run(
            [self.train_step, self.error, self.prediction],
            feed_dict={self.target_pl: target, self.actions_pl: actions, self.input_pl: batch['states_1']})
        self.monitor.report_action_finish('pred')
        return error_val.mean()

    def get_target(self, batch):
        prepare_batch(batch)
        batch_size = batch['batch_size']
        q_vals = self.evaluate_network(batch['states_2'])
        best_actions = np.argmax(q_vals, 1)
        target = np.choose(best_actions, q_vals.T)
        for i in xrange(batch_size):
            if batch['terminations'][i]:
                mul = 0
            else:
                mul = self.gamma
            target[i] = batch['rewards'][i] + mul * target[i]
        return target

    def evaluate_network(self, states_batch):
        return self.session.run(self.network.output, feed_dict={self.input_pl: states_batch})
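And a comparable sketch for the TensorFlow variant, written against the old TF 0.x API used above (tf.mul, tf.initialize_all_variables); get_inference/get_inference2, prepare_batch and Monitoring are assumed to come from the surrounding project, and the sketch assumes prepare_batch accepts states already laid out as [batch, height, width, frames]:

import numpy as np
import tensorflow as tf

height, width, state_frames, num_actions, batch_size = 84, 84, 4, 6, 32
with tf.Session() as session:
    net = DeepQNetwork(height, width, session, num_actions, state_frames, gamma=0.99)
    batch = {
        'batch_size': batch_size,
        'states_1': np.random.rand(batch_size, height, width, state_frames),
        'states_2': np.random.rand(batch_size, height, width, state_frames),
        'actions': np.random.randint(num_actions, size=batch_size),
        'rewards': np.random.rand(batch_size),
        'terminations': np.zeros(batch_size, dtype=bool),
    }
    error = net.train(batch)
    # get_best_action() expects a single state as [frames, height, width] and transposes it itself.
    action, q_values = net.get_best_action(batch['states_1'][0].transpose([2, 0, 1]))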