Пример #1
0
        def inner_loop(
            i,
            hit_eos,
            next_id,
            next_id_tag,
            decoded_ids,
            decoded_ids_tag,
            cache,
            log_prob,
        ):
            """Run one decoding step for the token stream and the tag stream.

            Args:
                i: current step counter; passed to `symbols_to_logits_fn` and
                    returned incremented by one.
                hit_eos: per-example flags marking sequences that already
                    produced `eos_id`.
                next_id: token ids fed to `symbols_to_logits_fn` this step.
                next_id_tag: tag ids fed to `symbols_to_logits_fn` this step.
                decoded_ids: token ids emitted so far; grown along axis 1.
                decoded_ids_tag: tag ids emitted so far; grown along axis 1.
                cache: decoding state threaded through `symbols_to_logits_fn`.
                log_prob: running log-probability of the emitted tokens.

            Returns:
                The updated loop variables, in the same order as the arguments.
            """

            def _draw(step_logits):
                """Sample ids from `step_logits` with the configured method."""
                method = hparams.sampling_method
                if method == 'random_per_example':
                    return common_layers.sample_temperature_per_example(
                        step_logits, sampling_temperature, top_k)
                # 'argmax' is expressed as sampling with temperature zero.
                step_temperature = (
                    0.0 if method == 'argmax' else sampling_temperature)
                return common_layers.sample_with_temperature(
                    step_logits, step_temperature, top_k)

            token_logits, tag_logits, cache = symbols_to_logits_fn(
                next_id, next_id_tag, i, cache)
            token_log_probs = common_layers.log_prob_from_logits(token_logits)

            # Tokens first, then tags, both with the same sampling settings.
            next_id = _draw(token_logits)
            next_id_tag = _draw(tag_logits)

            # Look up the log-probability of each example's chosen token.
            batch_rows = tf.range(tf.to_int64(batch_size))
            chosen = tf.stack([batch_rows, next_id], axis=1)
            # Examples that were already finished contribute nothing; the EOS
            # flag is only updated afterwards, so the EOS token itself counts.
            still_running = 1 - tf.to_float(hit_eos)
            log_prob += tf.gather_nd(token_log_probs, chosen) * still_running
            hit_eos |= tf.equal(next_id, eos_id)

            # Append the new ids as one extra column on each stream.
            next_id = tf.expand_dims(next_id, axis=1)
            next_id_tag = tf.expand_dims(next_id_tag, axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            decoded_ids_tag = tf.concat([decoded_ids_tag, next_id_tag], axis=1)

            updated = (
                i + 1,
                hit_eos,
                next_id,
                next_id_tag,
                decoded_ids,
                decoded_ids_tag,
                cache,
                log_prob,
            )
            return updated
Пример #2
0
  def infer(self, features, *args, **kwargs):
    """Sample predictions, then optionally refine them with extra passes.

    A first pass samples with temperature 0.0. After that,
    `hparams.autoregressive_decode_steps` further passes feed the current
    samples back in as "targets" and keep, for pass `k`, the first `k`
    positions of the previous samples.

    Side effects: `features["targets"]` is overwritten, and
    `hparams.autoregressive_dropout` / `hparams.sampling_temp` are modified
    in place and not restored.

    Args:
      features: dict of input features; a falsy value is treated as empty.
      *args: unused.
      **kwargs: unused.

    Returns:
      The sampled tensor, with the shape of the first pass's samples.
    """
    features = features or {}

    # Temporarily lift low-rank inputs by one axis, remembering the original
    # so it can be put back before returning.
    original_inputs = None
    if "inputs" in features:
      if len(features["inputs"].shape) < 4:
        original_inputs = features["inputs"]
        features["inputs"] = tf.expand_dims(original_inputs, 2)

    # First pass: needs some "targets", so fall back to a zero placeholder.
    try:
      num_channels = self.hparams.problem.num_channels
    except AttributeError:
      num_channels = 1
    if "targets" not in features:
      placeholder_shape = [self.hparams.batch_size, 1, 1, num_channels]
      features["targets"] = tf.zeros(placeholder_shape, dtype=tf.int32)
    logits, _ = self(features)  # pylint: disable=not-callable
    samples = common_layers.sample_with_temperature(logits, 0.0)
    shape = common_layers.shape_list(samples)
    flat_shape = [shape[0], -1, shape[3]]

    # Refinement passes. Dropout starts at 0.2 and is stepped down on each of
    # the last two passes; the temperature is halved on the second-to-last
    # pass and forced to zero on the last one.
    num_passes = self.hparams.autoregressive_decode_steps
    self.hparams.autoregressive_dropout = 0.2
    for step in range(num_passes):
      passes_left = num_passes - step
      if passes_left == 2:
        self.hparams.autoregressive_dropout -= 0.1
        self.hparams.sampling_temp /= 2
      if passes_left == 1:
        self.hparams.autoregressive_dropout -= 0.1
        self.hparams.sampling_temp = 0.0
      features["targets"] = samples
      previous_flat = tf.reshape(samples, flat_shape)
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        logits, _ = self(features)  # pylint: disable=not-callable
        samples = common_layers.sample_with_temperature(
            logits, self.hparams.sampling_temp)
        resampled_flat = tf.reshape(samples, flat_shape)
        # Keep the first `step` positions from the previous pass.
        merged_flat = tf.concat(
            [previous_flat[:, :step, :], resampled_flat[:, step:, :]],
            axis=1)
        samples = tf.reshape(merged_flat, shape)

    # Restore inputs to not confuse Estimator in edge cases.
    if original_inputs is not None:
      features["inputs"] = original_inputs

    return samples
Пример #3
0
  def infer(self, features, *args, **kwargs):
    """Produce model predictions by sampling, with optional re-sampling.

    The model is first run once and sampled at temperature 0.0. It is then
    re-run `hparams.autoregressive_decode_steps` times with the current
    samples as "targets"; re-run number `k` keeps the first `k` flattened
    positions of the previous samples and replaces the rest.

    Note that this mutates its inputs: `features["targets"]` is replaced,
    and `hparams.autoregressive_dropout` and `hparams.sampling_temp` are
    changed and left changed.

    Args:
      features: feature dict; an empty or `None` value becomes `{}`.
      *args: unused.
      **kwargs: unused.

    Returns:
      The final samples, shaped like the samples from the first run.
    """
    features = features or {}

    # Inputs of rank below 4 get an extra axis at position 2 for the duration
    # of this call; the caller's tensor is restored at the end.
    saved_inputs = None
    if "inputs" in features:
      if len(features["inputs"].shape) < 4:
        saved_inputs = features["inputs"]
        features["inputs"] = tf.expand_dims(saved_inputs, 2)

    # Initial run. Without targets, use an all-zero int32 placeholder.
    try:
      num_channels = self.hparams.problem.num_channels
    except AttributeError:
      num_channels = 1
    if "targets" not in features:
      zeros_shape = [self.hparams.batch_size, 1, 1, num_channels]
      features["targets"] = tf.zeros(zeros_shape, dtype=tf.int32)
    logits, _ = self(features)  # pylint: disable=not-callable
    samples = common_layers.sample_with_temperature(logits, 0.0)
    shape = common_layers.shape_list(samples)
    shape_1d = [shape[0], -1, shape[3]]

    # Re-sampling runs: anneal dropout and temperature over the last two.
    total_steps = self.hparams.autoregressive_decode_steps
    self.hparams.autoregressive_dropout = 0.2
    for pos in range(total_steps):
      remaining = total_steps - pos
      if remaining == 2:
        self.hparams.autoregressive_dropout -= 0.1
        self.hparams.sampling_temp /= 2
      if remaining == 1:
        self.hparams.autoregressive_dropout -= 0.1
        self.hparams.sampling_temp = 0.0
      features["targets"] = samples
      kept_1d = tf.reshape(samples, shape_1d)
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        logits, _ = self(features)  # pylint: disable=not-callable
        samples = common_layers.sample_with_temperature(
            logits, self.hparams.sampling_temp)
        fresh_1d = tf.reshape(samples, shape_1d)
        # Positions before `pos` stay as they were; the rest are replaced.
        stitched_1d = tf.concat(
            [kept_1d[:, :pos, :], fresh_1d[:, pos:, :]], axis=1)
        samples = tf.reshape(stitched_1d, shape)

    # Restore inputs to not confuse Estimator in edge cases.
    if saved_inputs is not None:
      features["inputs"] = saved_inputs

    return samples
Пример #4
0
 def __init__(
     self, batch_size, observation_space, action_space, policy_hparams,
     policy_dir, sampling_temp
 ):
   """Build the policy graph in a private tf.Graph and restore its weights.

   Args:
     batch_size: number of observations processed per call; used as the
       leading dimension of the observation placeholder.
     observation_space: space whose `shape` and `dtype` define the
       placeholder (read back through `self.observation_space`).
     action_space: action space handed to `rl.get_policy`.
     policy_hparams: hyperparameters for the policy; `policy_network` is
       also used as the variable-name scope to restore.
     policy_dir: directory passed to `trainer_lib.restore_checkpoint`.
     sampling_temp: temperature used both for sampling actions and for the
       action probabilities. A value of 0.0 means greedy selection.
   """
   super(PolicyAgent, self).__init__(
       batch_size, observation_space, action_space
   )
   self._sampling_temp = sampling_temp
   with tf.Graph().as_default():
     self._observations_t = tf.placeholder(
         shape=((batch_size,) + self.observation_space.shape),
         dtype=self.observation_space.dtype
     )
     (logits, self._values_t) = rl.get_policy(
         self._observations_t, policy_hparams, self.action_space
     )
     actions = common_layers.sample_with_temperature(logits, sampling_temp)
     if isinstance(sampling_temp, (int, float)) and sampling_temp == 0:
       # Dividing the logits by a zero temperature would produce inf/NaN
       # probabilities. Use the zero-temperature limit instead: all mass
       # on the highest-scoring action.
       greedy_actions = tf.argmax(logits, axis=-1)
       self._probs_t = tf.one_hot(
           greedy_actions, tf.shape(logits)[-1], dtype=logits.dtype
       )
     else:
       self._probs_t = tf.nn.softmax(logits / sampling_temp)
     self._actions_t = tf.cast(actions, tf.int32)
     # Only restore the variables that belong to the policy network.
     model_saver = tf.train.Saver(
         tf.global_variables(policy_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
     )
     # NOTE(review): the session is kept open for the agent's lifetime and
     # is not closed here; confirm it is released elsewhere.
     self._sess = tf.Session()
     # Initialize everything first, then overwrite with checkpoint values.
     self._sess.run(tf.global_variables_initializer())
     trainer_lib.restore_checkpoint(policy_dir, model_saver, self._sess)
Пример #5
0
 def inner_loop(i, next_id, decoded_ids, cache):
   """Decode one more id and append it to `decoded_ids`.

   Args:
     i: step counter, forwarded to `symbols_to_logits_fn`.
     next_id: ids produced by the previous step.
     decoded_ids: ids emitted so far; one column is added along axis 1.
     cache: decoding state threaded through `symbols_to_logits_fn`.

   Returns:
     Tuple `(i + 1, next_id, decoded_ids, cache)` for the next iteration.
   """
   step_logits, cache = symbols_to_logits_fn(next_id, i, cache)
   # "argmax" decoding is sampling with temperature zero.
   if hparams.sampling_method == "argmax":
     temperature = 0.0
   else:
     temperature = hparams.sampling_temp
   sampled = common_layers.sample_with_temperature(step_logits, temperature)
   next_id = tf.expand_dims(sampled, axis=1)
   decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
   return i + 1, next_id, decoded_ids, cache
Пример #6
0
 def inner_loop(i, next_id, decoded_ids, cache):
   """Advance decoding by one step.

   Computes logits for the previous ids via `symbols_to_logits_fn`, samples
   the next ids (temperature 0.0 when the sampling method is "argmax",
   `hparams.sampling_temp` otherwise) and appends them to `decoded_ids`.

   Returns:
     The loop variables for the next step:
     `(i + 1, next_id, decoded_ids, cache)`.
   """
   logits_now, cache = symbols_to_logits_fn(next_id, i, cache)
   use_argmax = hparams.sampling_method == "argmax"
   temperature = 0.0 if use_argmax else hparams.sampling_temp
   picked = common_layers.sample_with_temperature(logits_now, temperature)
   # Give the sampled ids a length-1 axis so they concatenate as a column.
   next_id = tf.expand_dims(picked, axis=1)
   decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
   next_step = i + 1
   return next_step, next_id, decoded_ids, cache
Пример #7
0
 def inner_loop(i, finished, next_id, decoded_ids, cache):
   """Decode one more id, tracking which examples have emitted `eos_id`.

   Args:
     i: step counter, forwarded to `symbols_to_logits_fn`.
     finished: per-example flags, OR-ed with "this step produced eos_id".
     next_id: ids produced by the previous step.
     decoded_ids: ids emitted so far; one column is added along axis 1.
     cache: decoding state threaded through `symbols_to_logits_fn`.

   Returns:
     Tuple `(i + 1, finished, next_id, decoded_ids, cache)`.
   """
   step_logits, cache = symbols_to_logits_fn(next_id, i, cache)
   # "argmax" decoding is sampling with temperature zero.
   if hparams.sampling_method == "argmax":
     temperature = 0.0
   else:
     temperature = hparams.sampling_temp
   sampled = common_layers.sample_with_temperature(step_logits, temperature)
   finished |= tf.equal(sampled, eos_id)
   next_id = tf.expand_dims(sampled, axis=1)
   decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
   return i + 1, finished, next_id, decoded_ids, cache
Пример #8
0
 def inner_loop(i, finished, next_id, decoded_ids, cache):
   """Advance decoding by one step and update the finished flags.

   Samples the next ids from the logits returned by `symbols_to_logits_fn`
   (temperature 0.0 for the "argmax" sampling method, otherwise
   `hparams.sampling_temp`), marks examples whose new id equals `eos_id` as
   finished, and appends the new ids to `decoded_ids`.

   Returns:
     The loop variables for the next step:
     `(i + 1, finished, next_id, decoded_ids, cache)`.
   """
   logits_now, cache = symbols_to_logits_fn(next_id, i, cache)
   use_argmax = hparams.sampling_method == "argmax"
   temperature = 0.0 if use_argmax else hparams.sampling_temp
   picked = common_layers.sample_with_temperature(logits_now, temperature)
   finished |= tf.equal(picked, eos_id)
   # Give the sampled ids a length-1 axis so they concatenate as a column.
   next_id = tf.expand_dims(picked, axis=1)
   decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
   next_step = i + 1
   return next_step, finished, next_id, decoded_ids, cache
Пример #9
0
        def inner_loop(cache_flag, i, finished, next_id, decoded_ids, cache):
            """Decode one id and record it in the sentence cache.

            Args:
                cache_flag: loop variable replaced each step by the result of
                    the `sentence_cache.AddMultipleEntries` py_func call.
                i: step counter, forwarded to `symbols_to_logits_fn`.
                finished: per-example flags, OR-ed with "produced eos_id".
                next_id: ids produced by the previous step.
                decoded_ids: ids emitted so far; grown along axis 1.
                cache: decoding state threaded through `symbols_to_logits_fn`.

            Returns:
                Tuple `(cache_flag, i + 1, finished, next_id, decoded_ids,
                cache)`.
            """
            step_logits, cache, step_output = symbols_to_logits_fn(
                next_id, i, cache)
            # "argmax" decoding is sampling with temperature zero.
            if hparams.sampling_method == "argmax":
                temperature = 0.0
            else:
                temperature = hparams.sampling_temp
            sampled = common_layers.sample_with_temperature(
                step_logits, temperature)
            finished |= tf.equal(sampled, eos_id)
            next_id = tf.expand_dims(sampled, axis=1)

            # Hand the new ids and the step's extra output to the Python-side
            # cache. py_func loses static shape info, so declare the result a
            # scalar explicitly.
            cache_flag = tf.py_func(
                sentence_cache.AddMultipleEntries,
                [next_id, step_output],
                tf.int64)
            cache_flag.set_shape(tf.TensorShape([]))

            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            return cache_flag, i + 1, finished, next_id, decoded_ids, cache
Пример #10
0
    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """Decode one id and accumulate the sequence log-probability.

      Args:
        i: step counter, forwarded to `symbols_to_logits_fn`.
        hit_eos: per-example flags marking sequences that already produced
          `eos_id` on an earlier step.
        next_id: ids produced by the previous step.
        decoded_ids: ids emitted so far; one column is added along axis 1.
        cache: decoding state threaded through `symbols_to_logits_fn`.
        log_prob: running per-example log-probability of the emitted ids.

      Returns:
        Tuple `(i + 1, hit_eos, next_id, decoded_ids, cache, log_prob)`.
      """
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      # "argmax" decoding is sampling with temperature zero.
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      # Only sequences that had not finished before this step contribute.
      # Without the mask, ids generated after EOS keep lowering the score of
      # sequences that are already complete.
      log_prob += tf.gather_nd(
          log_probs, log_prob_indices) * (1 - tf.to_float(hit_eos))
      # Update the flag after accumulating, so the log-probability of the
      # first EOS itself is included and later ids are not.
      hit_eos |= tf.equal(next_id, eos_id)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob
Пример #11
0
      def env_step(arg1, arg2, arg3):  # pylint: disable=unused-argument
        """Sample an action from the policy and advance the environment.

        The three positional arguments are ignored; the step reads
        `obs_copy`, `ppo_hparams`, `batch_env`, `sampling_temp` and
        `num_agents` from the enclosing scope.

        Returns:
          Tuple `(pdf, value_function, done)`, each reshaped to
          `(num_agents,)` and only evaluated after the simulation step's
          reward and done outputs.
        """
        policy_logits, value_estimate = get_policy(
            obs_copy, ppo_hparams, batch_env.action_space
        )
        sampled = common_layers.sample_with_temperature(
            policy_logits, sampling_temp)
        action = tf.cast(sampled, tf.int32)

        reward, done = batch_env.simulate(action[:, 0, ...])

        # Probability the policy assigned to the action that was taken.
        action_dist = tfp.distributions.Categorical(logits=policy_logits)
        per_agent = (num_agents,)
        pdf = tf.reshape(action_dist.prob(action), shape=per_agent)
        value_estimate = tf.reshape(value_estimate, shape=per_agent)
        done = tf.reshape(done, shape=per_agent)

        # Force the simulation to run before any of the outputs are read.
        with tf.control_dependencies([reward, done]):
          return (
              tf.identity(pdf),
              tf.identity(value_estimate),
              tf.identity(done),
          )
Пример #12
0
def pixels_from_softmax(frame_logits, pure_sampling=False,
                        temperature=1.0, gumbel_noise_factor=0.2):
  """Turn per-pixel softmax logits into pixel values.

  Args:
    frame_logits: logits whose last axis indexes the pixel value; the soft
      path weights that axis by `range(256)`.
    pure_sampling: if true, sample each pixel directly instead of averaging.
    temperature: sampling temperature; 0.0 also selects direct sampling.
    gumbel_noise_factor: scale applied to the Gumbel noise on the soft path.

  Returns:
    Either the directly sampled values, or a rounded, noise-perturbed
    expectation over pixel values whose gradient bypasses the rounding.
  """
  sample_directly = pure_sampling or temperature == 0.0
  if sample_directly:
    return common_layers.sample_with_temperature(frame_logits, temperature)

  # Pixel values 0..255, given leading singleton axes so they broadcast
  # against the last axis of the logits.
  pixel_values = tf.to_float(tf.range(256))
  rank = len(frame_logits.get_shape().as_list())
  for _ in range(rank - 1):
    pixel_values = tf.expand_dims(pixel_values, axis=0)

  # Perturb the log-probabilities with scaled Gumbel noise, then take the
  # softmax-weighted average of the pixel values.
  log_probs = tf.nn.log_softmax(frame_logits)
  noise = discretization.gumbel_sample(common_layers.shape_list(log_probs))
  scaled_noise = noise * gumbel_noise_factor
  weights = tf.nn.softmax((log_probs + scaled_noise) / temperature, axis=-1)
  expected = tf.reduce_sum(weights * pixel_values, axis=-1)

  # Round on the forward pass, not on the backward one.
  rounding_offset = tf.stop_gradient(tf.round(expected) - expected)
  return expected + rounding_offset
Пример #13
0
def pixels_from_softmax(frame_logits, pure_sampling=False,
                        temperature=1.0, gumbel_noise_factor=0.2):
  """Generate pixel values from the logits of a per-pixel softmax.

  With `pure_sampling` set, or with a temperature of 0.0, every pixel is
  sampled through `common_layers.sample_with_temperature`. Otherwise the
  log-softmax of the logits is perturbed with Gumbel noise (scaled by
  `gumbel_noise_factor`), sharpened by `temperature`, and used to average
  the values 0..255 along the last axis. The average is rounded in the
  forward pass only.

  Returns:
    The sampled values, or the rounded soft average.
  """
  if pure_sampling or temperature == 0.0:
    direct = common_layers.sample_with_temperature(frame_logits, temperature)
    return direct

  # Gumbel-sample from the pixel softmax and average by pixel values.
  intensity = tf.to_float(tf.range(256))
  leading_axes = len(frame_logits.get_shape().as_list()) - 1
  for _ in range(leading_axes):
    # Prepend singleton axes until the rank matches that of the logits.
    intensity = tf.expand_dims(intensity, axis=0)

  log_pixel_probs = tf.nn.log_softmax(frame_logits)
  noise_shape = common_layers.shape_list(log_pixel_probs)
  gumbel = discretization.gumbel_sample(noise_shape) * gumbel_noise_factor

  soft_choice = tf.nn.softmax(
      (log_pixel_probs + gumbel) / temperature, axis=-1)
  averaged = tf.reduce_sum(soft_choice * intensity, axis=-1)
  # The stop_gradient term is (round(x) - x): the forward value is rounded,
  # while the gradient is that of the unrounded average.
  return averaged + tf.stop_gradient(tf.round(averaged) - averaged)
Пример #14
0
    def inject_latent(self, layer, features, filters):
        """Inject a deterministic latent based on the target frame.

        A vector `d` of `hparams.bottleneck_bits` values is produced and
        merged into `layer` through `add_d`. When predicting, `d` is either
        random (`full_latent_tower`) or sampled by an LSTM predictor; both
        give +/-1 entries. Otherwise `d` is computed from
        `features["cur_target_frame"]` and `features["inputs"]`.

        Args:
            layer: tensor to inject the latent into; its first dimension is
                used as the batch size and its last as the channel count.
            features: feature dict; "cur_target_frame" and "inputs" are read
                outside of prediction.
            filters: ignored; rebound to `hparams.hidden_size`.

        Returns:
            Tuple `(layer_with_latent, pred_loss)`. `pred_loss` is 0.0 when
            predicting or when `full_latent_tower` is set; otherwise it is
            the cross-entropy of the LSTM latent predictor.
        """
        # The incoming `filters` value is deliberately discarded; the name is
        # reused below for the embedding width.
        del filters
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)
        batch_size = layer_shape[0]
        state_size = hparams.latent_predictor_state_size
        # Predictor components, shared by the prediction-time sampling loop
        # and the training-time loss below. Each latent is handled as
        # `bottleneck_bits // 8` symbols with 256 possible values.
        lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
        discrete_predict = tf.layers.Dense(256, name="discrete_predict")
        discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")

        def add_d(layer, d):
            """Merge latent `d` into `layer` after projecting it to its width.

            Purely additive unless `hparams.complex_addn` is set, in which
            case `layer` is first gated by a sigmoid of one projection and a
            second projection is then added.
            """
            z_mul = tf.layers.dense(d, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tf.layers.dense(d, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if self.is_predicting:
            if hparams.full_latent_tower:
                # No predictor in this mode: draw the latent uniformly.
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
            else:
                # Seed the LSTM input and both state parts from the
                # flattened layer.
                layer_pred = tf.reshape(
                    layer, [batch_size, prod(layer_shape[1:])])
                prediction = tf.layers.dense(layer_pred,
                                             state_size,
                                             name="istate")
                c_state = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="cstate")
                m_state = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="mstate")
                state = (c_state, m_state)
                outputs = []
                # Sample one symbol per step, feeding the embedding of each
                # sampled symbol back in as the next input.
                for i in range(hparams.bottleneck_bits // 8):
                    output, state = lstm_cell(prediction, state)
                    discrete_logits = discrete_predict(output)
                    discrete_samples = common_layers.sample_with_temperature(
                        discrete_logits, hparams.latent_predictor_temperature)
                    outputs.append(tf.expand_dims(discrete_samples, axis=1))
                    prediction = discrete_embed(
                        tf.one_hot(discrete_samples, 256))
                outputs = tf.concat(outputs, axis=1)
                # Presumably expands each symbol into 8 bits - confirm
                # against discretization.int_to_bit.
                outputs = discretization.int_to_bit(outputs, 8)
                rand = tf.reshape(outputs,
                                  [batch_size, 1, 1, hparams.bottleneck_bits])
            # Threshold at 0.5 to get values in {-1, +1}.
            d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            return add_d(layer, d), 0.0

        # Embed.
        frames = tf.concat([features["cur_target_frame"], features["inputs"]],
                           axis=-1)
        x = tf.layers.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        if hparams.full_latent_tower:
            # Stack of stride-2 convolutions; the channel count doubles for
            # the first `filter_double_steps` of them.
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tf.layers.conv2d(x,
                                         filters,
                                         kernel,
                                         activation=common_layers.belu,
                                         strides=(2, 2),
                                         padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            # Add two axes at position 1 (for a rank-2 input this gives
            # [batch, 1, 1, features]).
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        x = tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck")
        x0 = tf.tanh(x)
        # Forward value is the sign of x0 in {-1, +1}; the stop_gradient term
        # makes the gradient that of x0.
        d = x0 + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x0)) - 1.0 -
                                  x0)
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            # Train the LSTM predictor to reproduce the latent computed just
            # above (the noise-free one, not the training-time `d` computed
            # further down), without sending gradients back into it.
            d_pred = tf.reshape(tf.maximum(tf.stop_gradient(d), 0),
                                [batch_size, hparams.bottleneck_bits // 8, 8])
            d_int = discretization.bit_to_int(d_pred, 8)
            tf.summary.histogram("d_int", tf.reshape(d_int, [-1]))
            d_hot = tf.one_hot(d_int, 256, axis=-1)
            d_pred = discrete_embed(d_hot)
            layer_pred = tf.reshape(layer, [batch_size, prod(layer_shape[1:])])
            prediction0 = tf.layers.dense(layer_pred,
                                          state_size,
                                          name="istate")
            c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
            m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
            # Teacher forcing: step 0 sees the layer-derived seed, and each
            # later step sees the embedding of the previous true symbol.
            pred = tf.concat([tf.expand_dims(prediction0, axis=1), d_pred],
                             axis=1)
            state = (c_state, m_state)
            outputs = []
            for i in range(hparams.bottleneck_bits // 8):
                output, state = lstm_cell(pred[:, i, :], state)
                outputs.append(tf.expand_dims(output, axis=1))
            outputs = tf.concat(outputs, axis=1)
            d_int_pred = discrete_predict(outputs)
            pred_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=d_int_pred, labels=d_int)
            pred_loss = tf.reduce_mean(pred_loss)
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            # In training, `d` is recomputed from a noised bottleneck.
            x += tf.truncated_normal(common_layers.shape_list(x),
                                     mean=0.0,
                                     stddev=0.2)
            x = tf.tanh(x)
            # Multiplier of +1 where the uniform draw exceeds
            # `bottleneck_noise` and -1 elsewhere, i.e. random sign flips.
            noise = tf.random_uniform(common_layers.shape_list(x))
            noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise,
                                              noise)) - 1.0
            x *= noise
            d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 -
                                     x)
            # Per example, use the discretized `d` with probability `p` and
            # the continuous `x` otherwise.
            p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
            d = tf.where(tf.less(tf.random_uniform([batch_size]), p), d, x)
        return add_d(layer, d), pred_loss