import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn

# ConvLSTMCell, flatten, expand and categorical_sample are assumed to be
# provided by the surrounding project (a ConvLSTM cell that consumes
# flattened [batch, time, height * width * filters] tensors).

def __init__(self, ob_space, ac_space):
    self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

    # Four 3x3 conv layers; 'same' padding keeps the spatial size.
    for i in range(4):
        x = tf.layers.conv2d(inputs=x, filters=32, name="l{}".format(i + 1),
                             kernel_size=(3, 3), padding='same',
                             activation=tf.nn.relu)

    image_size = 32
    size = 32
    lstm = ConvLSTMCell(height=image_size, width=image_size, filters=size,
                        kernel=[3, 3])
    self.state_size = lstm.state_size

    # Introduce a fake batch dimension of 1 so the rollout length becomes
    # the time dimension, and flatten the spatial dimensions for the cell.
    # (Assumed fix: the original snippet fed the 4-D conv output straight
    # to dynamic_rnn, which would not match the cell's flattened state.)
    x = tf.expand_dims(flatten(x), [0])
    step_size = tf.shape(self.x)[:1]

    c_init = np.zeros((1, lstm.state_size.c), np.float32)
    h_init = np.zeros((1, lstm.state_size.h), np.float32)
    self.state_init = [c_init, h_init]
    c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
    h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
    self.state_in = [c_in, h_in]

    state_in = rnn.LSTMStateTuple(c_in, h_in)
    lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
        lstm, x, initial_state=state_in, sequence_length=step_size,
        time_major=False)
    lstm_c, lstm_h = lstm_state

    # Restore the spatial layout, then flatten per timestep for the heads.
    x = tf.reshape(expand(lstm_outputs, height=image_size, width=image_size,
                          filters=size),
                   shape=[-1, image_size ** 2 * size])
    self.logits = tf.layers.dense(inputs=x, units=ac_space, activation=None,
                                  name='action')
    self.vf = tf.reshape(tf.layers.dense(inputs=x, units=1, name='value'),
                         [-1])
    self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
    self.sample = categorical_sample(self.logits, ac_space)[0, :]
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      tf.get_variable_scope().name)
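
# --- Usage sketch (not from the original source): stepping the policy one
# observation at a time while threading the LSTM state through the rollout.
# `sess`, `env`, `obs` and `policy` are assumed/hypothetical names.
done = False
obs = env.reset()
last_state = policy.state_init
while not done:
    fetched = sess.run(
        [policy.sample, policy.vf] + policy.state_out,
        feed_dict={policy.x: [obs],
                   policy.state_in[0]: last_state[0],
                   policy.state_in[1]: last_state[1]})
    action, value, last_state = fetched[0], fetched[1], fetched[2:]
    # categorical_sample is assumed to return a one-hot action vector.
    obs, reward, done, _ = env.step(action.argmax())
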
def _build_graph(self):
    self.batch_size = tf.shape(self.image_pl)[0]

    # conv1 [batch, x, y, c]
    conv1 = tf.layers.conv2d(inputs=self.image_pl, filters=16,
                             kernel_size=[5, 5], padding='same',
                             activation=tf.nn.relu, name='conv1')
    # conv2 [batch, x, y, c]
    conv2 = tf.layers.conv2d(inputs=conv1, filters=16, kernel_size=[5, 5],
                             padding='same', activation=tf.nn.relu,
                             name='conv2')
    # conv3 [batch, x, y, c]
    conv3 = tf.layers.conv2d(inputs=conv2, filters=16, kernel_size=[3, 3],
                             padding='same', activation=tf.nn.relu,
                             name='conv3')

    # tiled [batch, time, x, y, c]: repeat the conv features per timestep,
    # then append the history channels.
    tiled = tf.tile(tf.expand_dims(conv3, axis=1),
                    multiples=[1, self.max_timesteps, 1, 1, 1])
    concat = tf.concat([tiled, self.history_pl], axis=4)

    with tf.variable_scope('rnn'):
        # Make the input time major for the RNN.
        # (see https://github.com/tensorflow/tensorflow/pull/5142)
        rnn_input = tf.transpose(concat, (1, 0, 2, 3, 4))
        rnn_cell = lambda: ConvLSTMCell(shape=[28, 28], filters=16,
                                        kernel=[3, 3])
        rnn_layers = 2
        multi_rnn_cell = tf.contrib.rnn.MultiRNNCell(
            [rnn_cell() for _ in range(rnn_layers)])
        self._rnn_zero_state_tensor = multi_rnn_cell.zero_state(
            self.batch_size, dtype=tf.float32)
        self._rnn_initial_state_pl = tuple(
            tf.contrib.rnn.LSTMStateTuple(
                tf.placeholder(
                    tf.float32,
                    shape=[None] + multi_rnn_cell.state_size[i].c.as_list(),
                    name='rnn_initial_state_c'),
                tf.placeholder(
                    tf.float32,
                    shape=[None] + multi_rnn_cell.state_size[i].h.as_list(),
                    name='rnn_initial_state_h'))
            for i in range(rnn_layers))
        rnn_output, self._rnn_final_state = tf.nn.dynamic_rnn(
            cell=multi_rnn_cell, inputs=rnn_input,
            sequence_length=self.duration_pl,
            initial_state=self._rnn_initial_state_pl, time_major=True)
        # Back to batch major after the RNN.
        # (see https://github.com/tensorflow/tensorflow/pull/5142)
        rnn_output = tf.transpose(rnn_output, (1, 0, 2, 3, 4))

    self.action_logits = tf.reshape(
        tf.layers.dense(
            inputs=tf.reshape(rnn_output, shape=[-1, 28 * 28 * 16]),
            units=self.action_size * self.action_size, name='prediction'),
        shape=[-1, self.max_timesteps, self.action_size * self.action_size])

    # Sample an action for the first timestep. Subtracting the row max
    # before tf.multinomial is for numerical stability only.
    first_action = tf.squeeze(
        tf.slice(self.action_logits, begin=[0, 0, 0], size=[-1, 1, -1]),
        axis=1)
    action_sample_dense = tf.squeeze(
        tf.multinomial(
            logits=first_action - tf.reduce_max(first_action, axis=[1],
                                                keep_dims=True),
            num_samples=1), [1])
    # Decode the dense sample into (row, column) coordinates.
    self.action_sample_coords = tf.stack(
        [action_sample_dense // self.action_size,
         action_sample_dense % self.action_size], axis=1)
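
# --- Inference sketch (assumed, not from the original source): since the
# initial-state placeholders above have no default, evaluate the zero state
# first and feed it explicitly. `sess`, `model` and the *_batch arrays are
# hypothetical names.
feed = {model.image_pl: image_batch,
        model.history_pl: history_batch,
        model.duration_pl: duration_batch}
zero_state = sess.run(model._rnn_zero_state_tensor, feed_dict=feed)
for pl, val in zip(model._rnn_initial_state_pl, zero_state):
    feed[pl.c] = val.c
    feed[pl.h] = val.h
coords, final_state = sess.run(
    [model.action_sample_coords, model._rnn_final_state], feed_dict=feed)
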
def _build_graph(self):
    # Assumed fix: the original snippet used self.batch_size without
    # defining it here; mirror the variant above.
    self.batch_size = tf.shape(self.image_pl)[0]

    # conv1 [batch, x, y, c]
    conv1 = tf.layers.conv2d(inputs=self.image_pl, filters=16,
                             kernel_size=[5, 5], padding='same',
                             activation=tf.nn.relu, name='conv1')
    # conv2 [batch, x, y, c]
    conv2 = tf.layers.conv2d(inputs=conv1, filters=16, kernel_size=[5, 5],
                             padding='same', activation=tf.nn.relu,
                             name='conv2')
    # conv3 [batch, x, y, c]
    conv3 = tf.layers.conv2d(inputs=conv2, filters=16, kernel_size=[3, 3],
                             padding='same', activation=tf.nn.relu,
                             name='conv3')

    # tiled [batch, time, x, y, c]: repeat the conv features per timestep,
    # then append the history channels.
    tiled = tf.tile(tf.expand_dims(conv3, axis=1),
                    multiples=[1, self.max_timesteps, 1, 1, 1])
    concat = tf.concat([tiled, self.history_pl], axis=4)

    with tf.variable_scope('rnn'):
        # Make the input time major for the RNN.
        # (see https://github.com/tensorflow/tensorflow/pull/5142)
        rnn_input = tf.transpose(concat, (1, 0, 2, 3, 4))
        rnn_cell = lambda: ConvLSTMCell(shape=[28, 28], filters=16,
                                        kernel=[3, 3])
        rnn_layers = 10
        multi_rnn_cell = tf.contrib.rnn.MultiRNNCell(
            [rnn_cell() for _ in range(rnn_layers)])
        self._rnn_zero_state = multi_rnn_cell.zero_state(self.batch_size,
                                                         dtype=tf.float32)
        # placeholder_with_default falls back to the zero state when no
        # initial state is fed.
        self._rnn_initial_state_pl = tuple(
            tf.contrib.rnn.LSTMStateTuple(
                tf.placeholder_with_default(
                    self._rnn_zero_state[i].c,
                    shape=[None] + multi_rnn_cell.state_size[i].c.as_list()),
                tf.placeholder_with_default(
                    self._rnn_zero_state[i].h,
                    shape=[None] + multi_rnn_cell.state_size[i].h.as_list()))
            for i in range(rnn_layers))
        rnn_output, self._rnn_final_state = tf.nn.dynamic_rnn(
            cell=multi_rnn_cell, inputs=rnn_input,
            sequence_length=self._duration_pl,
            initial_state=self._rnn_initial_state_pl, time_major=True)
        # Back to batch major after the RNN.
        # (see https://github.com/tensorflow/tensorflow/pull/5142)
        rnn_output = tf.transpose(rnn_output, (1, 0, 2, 3, 4))

    self._prediction_logits = tf.reshape(
        tf.layers.dense(
            inputs=tf.reshape(rnn_output, shape=[-1, 28 * 28 * 16]),
            units=self.prediction_size ** 2, name='prediction'),
        shape=[-1, self.max_timesteps, self.prediction_size ** 2])
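
# --- Contrast with the variant above (sketch; `sess`, `model` and the
# *_batch arrays are assumed names): because the initial state here is a
# placeholder_with_default, a fresh sequence needs no explicit state feed
# and the zero state is used automatically.
logits = sess.run(model._prediction_logits,
                  feed_dict={model.image_pl: image_batch,
                             model.history_pl: history_batch,
                             model._duration_pl: duration_batch})
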