示例#1
0
def feed_forward(
    state, data_shape, num_layers=3, activation=None, cut_gradient=False):
  """Build a dense network predicting data of `data_shape` from `state`.

  Args:
    state: Input tensor; the dense layers operate on its last axis.
    data_shape: Shape of a single predicted data point. It is concatenated
      onto the leading dimensions of `state`, so presumably a list.
    num_layers: Number of 100-unit ReLU hidden layers.
    activation: Activation of the output layer; None keeps it linear.
    cut_gradient: Whether to stop gradients from flowing back into `state`.

  Returns:
    A `tfd.Independent` around `tools.MSEDistribution(mean)` in which the
    trailing `len(data_shape)` axes are reinterpreted as event dimensions.
  """
  features = tf.stop_gradient(state) if cut_gradient else state
  for _ in range(num_layers):
    # Only the last axis changes, e.g. (40, 50, 1, 230) -> (40, 50, 1, 100).
    features = tf.layers.dense(features, 100, tf.nn.relu)
  # One output unit per element of a single data point.
  num_outputs = int(np.prod(data_shape))
  flat_mean = tf.layers.dense(features, num_outputs, activation)
  # Keep the leading axes of `state` and unflatten the last one.
  target_shape = tools.shape(state)[:-1] + data_shape
  mean = tf.reshape(flat_mean, target_shape)
  return tfd.Independent(tools.MSEDistribution(mean), len(data_shape))
示例#2
0
def decoder(state, data_shape):
  """Compute the data distribution of an observation from its state.

  Two 500-unit ReLU layers are followed by a linear output layer with one
  unit per element of `data_shape`; the result is reshaped so that the
  leading axes of `state` are kept and the last axis becomes `data_shape`.

  Args:
    state: Input tensor; the dense layers operate on its last axis.
    data_shape: List with the shape of a single observation.

  Returns:
    A `tfd.Independent` around `tools.MSEDistribution(mean)` in which the
    trailing `len(data_shape)` axes are reinterpreted as event dimensions.
  """
  # Size the output layer from data_shape instead of a fixed width, so the
  # reshape below is valid for any observation shape (a 26-element
  # data_shape still produces a 26-unit layer).
  num_outputs = 1
  for dim in data_shape:
    num_outputs *= int(dim)
  hidden = tf.layers.dense(state, 500, tf.nn.relu)
  hidden = tf.layers.dense(hidden, 500, tf.nn.relu)
  mean = tf.layers.dense(hidden, num_outputs, None)
  mean = tf.reshape(mean, tools.shape(state)[:-1] + data_shape)
  dist = tools.MSEDistribution(mean)
  dist = tfd.Independent(dist, len(data_shape))
  return dist
示例#3
0
def decoder(state, data_shape):
  """Decode a state into a distribution over 64x64x3 observations.

  The state is projected to 1024 features, viewed as a 1x1 image, and
  upsampled by four stride-2 transposed convolutions. With the default
  'valid' padding each layer maps a side length n to (n - 1) * 2 + kernel,
  giving 1 -> 5 -> 13 -> 30 -> 64.

  Args:
    state: Input tensor; its last axis holds the state features.
    data_shape: List with the shape of a single observation.

  Returns:
    A `tfd.Independent` around `tools.MSEDistribution(mean)` in which the
    trailing `len(data_shape)` axes are reinterpreted as event dimensions.
  """
  projected = tf.layers.dense(state, 1024, None)
  # Merge all leading axes into one batch axis of 1x1 feature maps.
  image = tf.reshape(projected, [-1, 1, 1, projected.shape[-1].value])
  for filters, kernel_size in ((128, 5), (64, 5), (32, 6)):
    image = tf.layers.conv2d_transpose(
        image, filters, kernel_size, strides=2, activation=tf.nn.relu)
  # Final layer is linear and emits the three image channels.
  image = tf.layers.conv2d_transpose(image, 3, 6, strides=2)
  assert image.shape[1:].as_list() == [64, 64, 3], image.shape
  # Restore the leading axes of `state` in front of the observation shape.
  mean = tf.reshape(image, tools.shape(state)[:-1] + data_shape)
  return tfd.Independent(tools.MSEDistribution(mean), len(data_shape))
示例#4
0
def decoder(state, data_shape):
    """Compute the data distribution of an observation from its state.

    The state is projected to 1024 features, viewed as a 1x1 image, and
    upsampled by a stack of transposed convolutions chosen by the
    module-level `obs_size`. The final layer is linear and emits
    `num_channels_x` channels (also read from module scope).

    Args:
        state: Input tensor; its last axis holds the state features.
        data_shape: List with the shape of a single observation.

    Returns:
        A `tfd.Independent` around `tools.MSEDistribution(mean)` in which
        the trailing `len(data_shape)` axes are reinterpreted as event
        dimensions.

    Raises:
        ValueError: If `obs_size` is not (32, 32), (64, 64) or (128, 128).
    """
    # Per resolution: (filters, kernel_size, strides) of every ReLU layer,
    # then the kernel size of the final linear stride-2 layer. With the
    # default 'valid' padding a layer maps a side length n to
    # (n - 1) * strides + kernel_size, which takes 1 up to the target size.
    architectures = {
        (32, 32): (((128, 5, 1), (64, 4, 2), (32, 4, 1)), 4),
        (64, 64): (((128, 5, 2), (64, 5, 2), (32, 6, 2)), 6),
        (128, 128): ((
            (256, 4, 1), (256, 4, 1), (128, 4, 1), (128, 3, 1),
            (64, 4, 2), (64, 4, 1), (32, 4, 2), (32, 4, 1)), 4),
    }

    # Fail early and explicitly instead of skipping every conv layer and
    # hitting an obscure reshape error further down.
    size = tuple(obs_size)
    if size not in architectures:
        raise ValueError(
            'Unsupported obs_size {}; expected one of {}.'.format(
                obs_size, sorted(architectures)))
    relu_layers, final_kernel = architectures[size]

    hidden = tf.layers.dense(state, 1024, None)
    # Merge all leading axes into one batch axis of 1x1 feature maps.
    hidden = tf.reshape(hidden, [-1, 1, 1, hidden.shape[-1].value])
    for filters, kernel_size, strides in relu_layers:
        hidden = tf.layers.conv2d_transpose(
            hidden, filters, kernel_size,
            strides=strides, activation=tf.nn.relu)
    mean = tf.layers.conv2d_transpose(
        hidden, num_channels_x, final_kernel, strides=2)

    # Sanity check that the layer stack really reaches the target size.
    expected_shape = list(size) + [num_channels_x]
    assert mean.shape[1:].as_list() == expected_shape, mean.shape

    # Restore the leading axes of `state` in front of the observation shape.
    mean = tf.reshape(mean, tools.shape(state)[:-1] + data_shape)
    dist = tools.MSEDistribution(mean)
    dist = tfd.Independent(dist, len(data_shape))
    return dist