def feed_forward(
    state, data_shape, num_layers=3, activation=None, cut_gradient=False,
    units=100):
  """Create a model returning unnormalized MSE distribution.

  Args:
    state: Input tensor; its last axis is the feature axis consumed by the
      dense layers, all leading axes are kept in the output.
    data_shape: Shape of one output sample. It is concatenated with
      `tools.shape(state)[:-1]`, so it presumably must be a list
      (TODO confirm against `tools.shape`).
    num_layers: Number of hidden ReLU dense layers.
    activation: Activation of the output (mean) layer; None means linear.
    cut_gradient: If True, wrap `state` in `tf.stop_gradient` so no gradient
      flows back into it.
    units: Width of every hidden dense layer (previously hard-coded to 100).

  Returns:
    `tfd.Independent` wrapping `tools.MSEDistribution(mean)`, with the
    trailing `len(data_shape)` axes treated as event dimensions.
  """
  hidden = state
  if cut_gradient:
    hidden = tf.stop_gradient(hidden)
  for _ in range(num_layers):
    # e.g. state (40, 50, 1, 230) --> hidden (40, 50, 1, units)
    hidden = tf.layers.dense(hidden, units, tf.nn.relu)
  # One output unit per element of a data sample; reshaped just below.
  mean = tf.layers.dense(hidden, int(np.prod(data_shape)), activation)
  # Restore the leading axes of `state` and append the per-sample shape,
  # e.g. data_shape [1] --> mean (40, 50, 1, 1).
  mean = tf.reshape(mean, tools.shape(state)[:-1] + data_shape)
  dist = tools.MSEDistribution(mean)
  dist = tfd.Independent(dist, len(data_shape))
  return dist
def decoder(state, data_shape):
  """Compute the data distribution of an observation from its state.

  MLP variant: two 500-unit ReLU dense layers followed by a linear dense
  layer with one output per element of `data_shape`.

  NOTE(review): this file defines `decoder` more than once; at module level
  only the last definition takes effect -- confirm which variant is intended.

  Args:
    state: Tensor whose last axis holds the state features.
    data_shape: Shape of one observation; the output is reshaped to
      `tools.shape(state)[:-1] + data_shape`.

  Returns:
    `tfd.Independent` over `tools.MSEDistribution(mean)`, with the trailing
    `len(data_shape)` axes as event dimensions.
  """
  hidden = tf.layers.dense(state, 500, tf.nn.relu)
  hidden = tf.layers.dense(hidden, 500, tf.nn.relu)
  # Size the output layer from data_shape so the reshape below always
  # matches, instead of only working when prod(data_shape) == 26.
  mean = tf.layers.dense(hidden, int(np.prod(data_shape)), None)
  mean = tf.reshape(mean, tools.shape(state)[:-1] + data_shape)
  dist = tools.MSEDistribution(mean)
  dist = tfd.Independent(dist, len(data_shape))
  return dist
def decoder(state, data_shape):
  """Decode a state into a distribution over 64x64x3 images.

  The state is projected to 1024 features, viewed as a 1x1 feature map and
  upsampled by four stride-2 transposed convolutions to 64x64x3.

  NOTE(review): this file defines `decoder` more than once; at module level
  only the last definition takes effect -- confirm which variant is intended.

  Args:
    state: Tensor whose last axis holds the state features.
    data_shape: Shape of one observation; the decoded images are reshaped to
      `tools.shape(state)[:-1] + data_shape`.

  Returns:
    `tfd.Independent` over `tools.MSEDistribution`, with the trailing
    `len(data_shape)` axes as event dimensions.
  """
  features = tf.layers.dense(state, 1024, None)
  # Fold all leading axes into one batch axis; features become channels.
  image = tf.reshape(features, [-1, 1, 1, features.shape[-1].value])
  # ReLU upsampling stages, all stride 2: side length 1 -> 5 -> 13 -> 30.
  for filters, kernel_size in ((128, 5), (64, 5), (32, 6)):
    image = tf.layers.conv2d_transpose(
        image, filters, kernel_size, strides=2, activation=tf.nn.relu)
  # Linear output stage: 30 -> 64 pixels with 3 channels.
  image = tf.layers.conv2d_transpose(image, 3, 6, strides=2)
  assert image.shape[1:].as_list() == [64, 64, 3], image.shape
  mean = tf.reshape(image, tools.shape(state)[:-1] + data_shape)
  return tfd.Independent(tools.MSEDistribution(mean), len(data_shape))
def decoder(state, data_shape):
  """Decode a state into a distribution over `obs_size` images.

  The state is projected to 1024 features, viewed as a 1x1 feature map and
  upsampled by a stack of transposed convolutions chosen by `obs_size`.
  `obs_size` and `num_channels_x` are read from the enclosing scope (they
  are not defined in this function).

  Args:
    state: Tensor whose last axis holds the state features.
    data_shape: Shape of one observation; the decoded images are reshaped to
      `tools.shape(state)[:-1] + data_shape`.

  Returns:
    `tfd.Independent` over `tools.MSEDistribution`, with the trailing
    `len(data_shape)` axes as event dimensions.

  Raises:
    ValueError: If `obs_size` is not (32, 32), (64, 64) or (128, 128).
  """
  # Normalize so a list such as [64, 64] also matches the tuples below.
  size = tuple(obs_size)
  # Per size: (filters, kernel_size, strides) of each ReLU transposed conv,
  # then (kernel_size, strides) of the final linear layer. With the default
  # 'valid' padding each layer maps a side of n pixels to
  # (n - 1) * strides + kernel_size.
  if size == (32, 32):
    # 1 -> 5 -> 12 -> 15 -> 32
    relu_layers = [(128, 5, 1), (64, 4, 2), (32, 4, 1)]
    out_kernel, out_strides = 4, 2
  elif size == (64, 64):
    # 1 -> 5 -> 13 -> 30 -> 64
    relu_layers = [(128, 5, 2), (64, 5, 2), (32, 6, 2)]
    out_kernel, out_strides = 6, 2
  elif size == (128, 128):
    # 1 -> 4 -> 7 -> 10 -> 12 -> 26 -> 29 -> 60 -> 63 -> 128
    relu_layers = [
        (256, 4, 1), (256, 4, 1), (128, 4, 1), (128, 3, 1),
        (64, 4, 2), (64, 4, 1), (32, 4, 2), (32, 4, 1)]
    out_kernel, out_strides = 4, 2
  else:
    # Previously an unknown size skipped every layer and every assert and
    # fed the 1x1x1024 tensor straight into the reshape below.
    raise ValueError(
        'Unsupported obs_size {!r}; expected (32, 32), (64, 64) or '
        '(128, 128).'.format(obs_size))
  hidden = tf.layers.dense(state, 1024, None)
  # Fold all leading axes into one batch axis; features become channels.
  hidden = tf.reshape(hidden, [-1, 1, 1, hidden.shape[-1].value])
  for filters, kernel_size, strides in relu_layers:
    hidden = tf.layers.conv2d_transpose(
        hidden, filters, kernel_size, strides=strides, activation=tf.nn.relu)
  mean = tf.layers.conv2d_transpose(
      hidden, num_channels_x, out_kernel, strides=out_strides)
  expected = [size[0], size[1], num_channels_x]
  assert mean.shape[1:].as_list() == expected, mean.shape
  mean = tf.reshape(mean, tools.shape(state)[:-1] + data_shape)
  dist = tools.MSEDistribution(mean)
  dist = tfd.Independent(dist, len(data_shape))
  return dist