Exemplo n.º 1
0
def get_trained_q(trained_var):
  """Get the trained inference distribution :math:`q` (c.f. section "Notation"
  in the documentation).

  Args:
    trained_var:
      `dict` object whose keys contain "a", "mu", and "zeta", with values
      being either numpy arrays or TensorFlow tensors (`tf.constant`), as the
      trained values of the variables in "nn4post". Other keys are ignored.

  Returns:
    An instance of `Mixture` whose categorical weights are
    `softmax(trained_var['a'])` and whose i-th component is an `Independent`
    `NormalWithSoftplusScale` built from `trained_var['mu'][i]` and
    `trained_var['zeta'][i]`.

  Raises:
    KeyError:
      If any of "a", "mu", or "zeta" is missing from `trained_var`.
  """

  var_names = ['a', 'mu', 'zeta']
  missing = [name for name in var_names if name not in trained_var]
  if missing:
    # Report only the available keys: formatting the whole dict would dump
    # every (possibly huge) trained array into the message.
    raise KeyError(
        '{0} is not in the keys of {1}.'.format(missing, list(trained_var)))

  # Convert only the variables that are actually used; numpy arrays become
  # constants, tensors are used as-is.
  _trained_var = {}
  for name in var_names:
    val = trained_var[name]
    _trained_var[name] = val if isinstance(val, tf.Tensor) else tf.constant(val)

  # `Categorical`'s first positional argument is `logits`; passing
  # `softmax(a)` positionally would apply softmax twice. Using `a` as logits
  # gives mixture weights softmax(a).
  cat = Categorical(logits=_trained_var['a'])
  components = [
      Independent(NormalWithSoftplusScale(mu, zeta))
      for mu, zeta in zip(
          tf.unstack(_trained_var['mu'], axis=0),
          tf.unstack(_trained_var['zeta'], axis=0),
      )
  ]
  return Mixture(cat, components)
Exemplo n.º 2
0
def get_trained_posterior(trained_var, param_shape):
  """Get the trained posterior of every model parameter as a mixture.

  Args:
    trained_var:
      `dict` object whose keys contain "a", "mu", and "zeta", with values
      being either numpy arrays or TensorFlow tensors (`tf.constant`), as the
      trained values of the variables in "nn4post". `trained_var['mu'][i]`
      and `trained_var['zeta'][i]` are parsed into per-parameter pieces for
      the i-th component.

    param_shape:
      `dict` with keys the parameter-names and values the associated shapes
      (as lists).

  Returns:
    Dictionary with keys the parameter-names (the keys of `param_shape`) and
    values instances of `Mixture` as the distributions that fit the
    associated posteriors. All mixtures share one categorical distribution
    with logits `trained_var['a']`.
  """

  n_c = trained_var['a'].shape[0]
  cat = Categorical(logits=trained_var['a'])

  # Split each component's `mu` / `zeta` into per-parameter pieces.
  # NOTE(review): assumes `get_parse_param(param_shape)` returns a callable
  # mapping one component's vector to a dict keyed by the keys of
  # `param_shape` -- confirm.
  parse_param = get_parse_param(param_shape)
  mu_list = [parse_param(trained_var['mu'][i]) for i in range(n_c)]
  zeta_list = [parse_param(trained_var['zeta'][i]) for i in range(n_c)]

  trained_posterior = {}

  # Iterate over the parameter names. Iterating over `trained_var` would
  # yield 'a', 'mu', 'zeta', which are not keys of the parsed dicts.
  for param_name in param_shape:
    components = [
        Independent(NormalWithSoftplusScale(mu[param_name], zeta[param_name]))
        for mu, zeta in zip(mu_list, zeta_list)
    ]
    trained_posterior[param_name] = Mixture(cat, components)

  return trained_posterior
Exemplo n.º 3
0
 def __init__(self, n_params, scale_offset=0., *args, **kwargs):
     """Initialise the module, naming it after its runtime class.

     Args:
       n_params: Stored as `self._n_params`. NOTE(review): presumably the
         number of parameters of the Gaussian -- confirm.
       scale_offset: Stored as `self._scale_offset`.
       *args, **kwargs: Forwarded, after the two positional arguments, to
         every `NormalWithSoftplusScale` built by `self._create_distrib`.
     """
     module_name = self.__class__.__name__
     super(ParametrisedGaussian, self).__init__(module_name)

     self._n_params = n_params
     self._scale_offset = scale_offset

     def create_distrib(x, y):
         # The extra constructor arguments are captured from `__init__`.
         return NormalWithSoftplusScale(x, y, *args, **kwargs)

     self._create_distrib = create_distrib
Exemplo n.º 4
0
        `dict`, like `{'y': Y}`, where `Y` is an instance of
        `tf.distributions.Distribution`.
    """
    # shape: `[None, n_hiddens]`
    hidden = tf.sigmoid(tf.matmul(input_['x'], param['w_h']) + param['b_h'])
    # shape: `[None, n_outputs]`
    logits = tf.matmul(hidden, param['w_a']) + param['b_a']

    Y = Categorical(logits=logits)
    return {'y': Y}


# PRIOR
# Zero-mean Normal priors for the network parameters, created under the
# 'prior' name scope; `param_prior` below maps parameter names to them.
# NOTE(review): if `NormalWithSoftplusScale` applies softplus to `scale` (as
# tf.contrib.distributions documents), the effective standard deviations are
# softplus(10) ~= 10 for the weights and softplus(100) ~= 100 for the
# biases -- confirm this is the intended prior width.
with tf.name_scope('prior'):
    # Prior of `param['w_h']` (hidden-layer weights), shape [n_inputs, n_hiddens].
    w_h = NormalWithSoftplusScale(loc=tf.zeros([n_inputs, n_hiddens]),
                                  scale=tf.ones([n_inputs, n_hiddens]) * 10,
                                  name="w_h")
    # Prior of `param['w_a']` (output-layer weights), shape [n_hiddens, n_outputs].
    w_a = NormalWithSoftplusScale(loc=tf.zeros([n_hiddens, n_outputs]),
                                  scale=tf.ones([n_hiddens, n_outputs]) * 10,
                                  name="w_a")
    # Prior of `param['b_h']` (hidden-layer bias), shape [n_hiddens].
    b_h = NormalWithSoftplusScale(loc=tf.zeros([n_hiddens]),
                                  scale=tf.ones([n_hiddens]) * 100,
                                  name="b_h")
    # Prior of `param['b_a']` (output-layer bias), shape [n_outputs].
    b_a = NormalWithSoftplusScale(loc=tf.zeros([n_outputs]),
                                  scale=tf.ones([n_outputs]) * 100,
                                  name="b_a")

param_prior = {
    'w_h': w_h,
    'w_a': w_a,
    'b_h': b_h,
Exemplo n.º 5
0
    def _build(self, inpt, state):
        """Run one step of the recurrent cell.

        Input is unused; it's only to force a maximum number of steps.

        Args:
          inpt: Unused (see above).
          state: Six-element sequence `(img_flat, canvas_flat, what_code,
            where_code, hidden_state, presence)`. In this cell the incoming
            `what_code` and `where_code` are overwritten by fresh samples
            without being read.

        Returns:
          `(output, state)`: `output` is the list `[canvas_flat,
          decoded_flat, what_code, what_loc, what_scale, where_code,
          where_loc, where_scale, presence_prob, presence]`; `state` has the
          same layout as the input state, holding the updated values.
        """

        img_flat, canvas_flat, what_code, where_code, hidden_state, presence = state

        # Un-flatten the image to shape `(-1,) + img_size`.
        img_inpt = img_flat
        img = tf.reshape(img_inpt, (-1, ) + tuple(self._img_size))

        # Only the encoded image (not the previous codes) drives the
        # recurrent transition in this cell.
        inpt_encoding = self._input_encoder(img)
        with tf.variable_scope('rnn_inpt'):
            hidden_output, hidden_state = self._transition(
                inpt_encoding, hidden_state)

        # "Where" latent: `where_param` is unpacked into the positional
        # (loc, scale) arguments. Debug mode turns on argument validation
        # and disallows NaN statistics.
        where_param = self._transform_estimator(hidden_output)
        where_distrib = NormalWithSoftplusScale(
            *where_param,
            validate_args=self._debug,
            allow_nan_stats=not self._debug)
        where_loc, where_scale = where_distrib.loc, where_distrib.scale
        where_code = where_distrib.sample()

        # Crop a glimpse from `img` at the sampled `where_code`.
        cropped = self._spatial_transformer(img, where_code)

        with tf.variable_scope('presence'):
            presence_prob = self._steps_predictor(hidden_output)

            if self._explore_eps is not None:
                # Exploration: (1 - eps) * p + eps / 2 pulls the probability
                # towards 0.5, keeping it in [eps / 2, 1 - eps / 2] for p in
                # [0, 1].
                presence_prob = self._explore_eps / 2 + (
                    1 - self._explore_eps) * presence_prob

            if self._sample_presence:
                presence_distrib = Bernoulli(probs=presence_prob,
                                             dtype=tf.float32,
                                             validate_args=self._debug,
                                             allow_nan_stats=not self._debug)

                # Multiply by the previous presence carried in the state:
                # once a step samples 0, presence stays 0 for later steps.
                new_presence = presence_distrib.sample()
                presence *= new_presence

            else:
                # No sampling: the probability itself replaces (rather than
                # multiplies) the previous presence.
                presence = presence_prob

        # "What" latent: encode the glimpse and sample its code.
        what_params = self._glimpse_encoder(cropped)
        what_distrib = self._what_distrib(what_params)
        what_loc, what_scale = what_distrib.loc, what_distrib.scale
        what_code = what_distrib.sample()

        # Decode the glimpse and map it back to image space using the same
        # `where_code`.
        decoded = self._glimpse_decoder(what_code)
        inversed = self._inverse_transformer(decoded, where_code)

        with tf.variable_scope('rnn_outputs'):
            inversed_flat = tf.reshape(inversed, (-1, self._n_pix))

            # Add the presence-gated reconstruction to the running canvas.
            # NOTE(review): assumes `presence` broadcasts against
            # `(batch, n_pix)` (e.g. shape `(batch, 1)`) -- confirm.
            canvas_flat += presence * inversed_flat
            decoded_flat = tf.reshape(decoded, (-1, np.prod(self._crop_size)))

        output = [
            canvas_flat, decoded_flat, what_code, what_loc, what_scale,
            where_code, where_loc, where_scale, presence_prob, presence
        ]
        state = [
            img_flat, canvas_flat, what_code, where_code, hidden_state,
            presence
        ]
        return output, state
Exemplo n.º 6
0
DTYPE = 'float32'
SKIP_STEP = 50

# -- Gaussian Mixture Distribution
# Target distribution for this test: a mixture of `TARGET_N_C` components in
# `N_D` dimensions. The i-th component is centred at `(i - 1) * 3` in every
# coordinate (-3, 0, 3, ... for i = 0, 1, 2, ...).
with tf.name_scope('posterior'):

    # Mixture weights. NOTE(review): three entries, so this assumes
    # TARGET_N_C == 3 -- confirm.
    target_c = tf.constant([0.05, 0.25, 0.70], dtype=DTYPE)
    target_mu = tf.stack(
        [tf.ones([N_D]) * (i - 1) * 3 for i in range(TARGET_N_C)], axis=0)
    # If `NormalWithSoftplusScale` applies softplus to its scale argument (as
    # tf.contrib.distributions documents), zeros give every component a
    # scale of softplus(0) = ln 2.
    target_zeta_val = np.zeros([TARGET_N_C, N_D])
    target_zeta = tf.constant(target_zeta_val, dtype=DTYPE)

    cat = Categorical(probs=target_c)
    components = [
        Independent(NormalWithSoftplusScale(target_mu[i], target_zeta[i]))
        for i in range(TARGET_N_C)
    ]
    p = Mixture(cat, components)

    def log_posterior(theta):
        """Return the log-density of the target mixture `p` at `theta`."""
        return p.log_prob(theta)

# test!
# test 1
init_var = {
    'a':
        np.zeros([N_C], dtype=DTYPE),
    'mu':
        np.array([np.ones([N_D]) * (i - 1) * 3 for i in range(N_C)],
Exemplo n.º 7
0
    def _build(self, inpt, state):
        """Run one step of the recurrent cell.

        Args:
          inpt: Not read by this method.
          state: Six-element sequence `(img_flat, canvas_flat, what_code,
            where_code, hidden_state, presence)`. The previous step's
            `what_code`, `where_code`, and `presence` are fed to the RNN
            together with the image encoding.

        Returns:
          `(output, state)`: `output` is the list `[canvas_flat,
          decoded_flat, what_code, what_loc, what_scale, where_code,
          where_loc, where_scale, presence_prob, presence]`; `state` has the
          same layout as the input state, holding the updated values.
        """

        img_flat, canvas_flat, what_code, where_code, hidden_state, presence = state
        # Un-flatten the image to shape `(-1,) + img_size`.
        img = tf.reshape(img_flat, (-1, ) + tuple(self._img_size))

        inpt_encoding = img
        inpt_encoding = self._input_encoder(inpt_encoding)

        with tf.variable_scope('rnn_inpt'):
            # RNN input: image encoding concatenated (last axis) with the
            # previous step's codes and presence, then projected.
            rnn_inpt = tf.concat(
                (inpt_encoding, what_code, where_code, presence), -1)
            rnn_inpt = self._rnn_projection(rnn_inpt)
            hidden_output, hidden_state = self._transition(
                rnn_inpt, hidden_state)

        # "Where" latent: `where_param` is unpacked into the positional
        # (loc, scale) arguments. Debug mode turns on argument validation
        # and disallows NaN statistics.
        where_param = self._transform_estimator(hidden_output)
        where_distrib = NormalWithSoftplusScale(
            *where_param,
            validate_args=self._debug,
            allow_nan_stats=not self._debug)
        where_loc, where_scale = where_distrib.loc, where_distrib.scale
        where_code = where_distrib.sample()

        # Crop a glimpse from `img` at the sampled `where_code`.
        cropped = self._spatial_transformer(img, where_code)

        with tf.variable_scope('presence'):
            presence_prob = self._steps_predictor(hidden_output)

            if self._explore_eps is not None:
                # Clip to [eps, 1 - eps] in the forward pass only. The value
                # of `stop_gradient(clipped - p) + p` equals `clipped`, while
                # gradients flow to `presence_prob` as if unclipped.
                clipped_prob = tf.clip_by_value(presence_prob,
                                                self._explore_eps,
                                                1. - self._explore_eps)
                presence_prob = tf.stop_gradient(clipped_prob -
                                                 presence_prob) + presence_prob

            if self._sample_presence:
                presence_distrib = Bernoulli(probs=presence_prob,
                                             dtype=tf.float32,
                                             validate_args=self._debug,
                                             allow_nan_stats=not self._debug)

                # Multiply by the previous presence carried in the state:
                # once a step samples 0, presence stays 0 for later steps.
                new_presence = presence_distrib.sample()
                presence *= new_presence

            else:
                # No sampling: the probability itself replaces (rather than
                # multiplies) the previous presence.
                presence = presence_prob

        # "What" latent: encode the glimpse and sample its code.
        what_params = self._glimpse_encoder(cropped)
        what_distrib = self._what_distrib(what_params)
        what_loc, what_scale = what_distrib.loc, what_distrib.scale
        what_code = what_distrib.sample()
        # The decoder also sees `where_code`, but no gradient flows back
        # into `where_code` through this path.
        decoded = self._glimpse_decoder(
            tf.concat([what_code, tf.stop_gradient(where_code)], -1))
        inversed = self._inverse_transformer(decoded, where_code)

        with tf.variable_scope('rnn_outputs'):
            inversed_flat = tf.reshape(inversed, (-1, self._n_pix))

            # Add the presence-gated reconstruction to the running canvas.
            # NOTE(review): assumes `presence` broadcasts against
            # `(batch, n_pix)` (e.g. shape `(batch, 1)`) -- confirm.
            canvas_flat = canvas_flat + presence * inversed_flat  # * novelty_flat
            decoded_flat = tf.reshape(decoded, (-1, np.prod(self._crop_size)))

        output = [
            canvas_flat, decoded_flat, what_code, what_loc, what_scale,
            where_code, where_loc, where_scale, presence_prob, presence
        ]
        state = [
            img_flat, canvas_flat, what_code, where_code, hidden_state,
            presence
        ]
        return output, state