def inner_loop(
    i,
    hit_eos,
    next_id,
    next_id_tag,
    decoded_ids,
    decoded_ids_tag,
    cache,
    log_prob,
):
  """Advance decoding by one step for both the token head and the tag head.

  Args:
    i: current step index, forwarded to `symbols_to_logits_fn`.
    hit_eos: boolean tensor, True for rows whose token head has already
      produced `eos_id`.
    next_id: token ids from the previous step, fed to the model.
    next_id_tag: tag ids from the previous step, fed to the model.
    decoded_ids: token ids decoded so far; the new ids are appended on axis 1.
    decoded_ids_tag: tag ids decoded so far; the new ids are appended on
      axis 1.
    cache: decoding cache threaded through `symbols_to_logits_fn`.
    log_prob: running log-probability of the sampled tokens (tag head is not
      scored).

  Returns:
    The loop variables in the same order: `i + 1`, updated `hit_eos`, the new
    token and tag ids expanded on axis 1, the extended id tensors, the cache
    and the updated `log_prob`.
  """

  def _sample(head_logits):
    # Both heads share one sampling policy chosen by hparams.sampling_method.
    if hparams.sampling_method == 'random_per_example':
      return common_layers.sample_temperature_per_example(
          head_logits, sampling_temperature, top_k)
    if hparams.sampling_method == 'argmax':
      temp = 0.0
    else:
      temp = sampling_temperature
    return common_layers.sample_with_temperature(head_logits, temp, top_k)

  token_logits, tag_logits, cache = symbols_to_logits_fn(
      next_id, next_id_tag, i, cache)
  token_log_probs = common_layers.log_prob_from_logits(token_logits)

  sampled = _sample(token_logits)
  sampled_tag = _sample(tag_logits)

  # Select log p(sampled token) for every row of the batch.
  gather_idx = tf.stack([tf.range(tf.to_int64(batch_size)), sampled], axis=1)
  # Rows that had already hit EOS before this step contribute nothing; the
  # first EOS itself is still counted because hit_eos is updated afterwards.
  log_prob += tf.gather_nd(token_log_probs, gather_idx) * (
      1 - tf.to_float(hit_eos))
  hit_eos |= tf.equal(sampled, eos_id)

  new_id = tf.expand_dims(sampled, axis=1)
  new_tag = tf.expand_dims(sampled_tag, axis=1)
  return (
      i + 1,
      hit_eos,
      new_id,
      new_tag,
      tf.concat([decoded_ids, new_id], axis=1),
      tf.concat([decoded_ids_tag, new_tag], axis=1),
      cache,
      log_prob,
  )
def infer(self, features, *args, **kwargs):
  """Produce predictions from the model by sampling.

  The model is run once and its logits are sampled greedily (temperature
  0.0). Then `hparams.autoregressive_decode_steps` refinement passes follow.
  In refinement pass `i`:
    * the current samples are fed back as `features["targets"]`;
    * the model is re-run with variable reuse;
    * positions `i` onward along flattened axis 1 are replaced by fresh
      samples at `hparams.sampling_temp`, while earlier positions are kept.

  During refinement, `hparams.autoregressive_dropout` starts at 0.2 and is
  lowered by 0.1 on each of the last two passes. `hparams.sampling_temp` is
  halved on the second-to-last pass and set to 0.0 on the last. Both fields
  are restored before returning. Without this, a later call (e.g. when the
  inference graph is rebuilt) would start from a temperature already driven
  to 0.0 and sample greedily.

  Args:
    features: dict of input tensors, or a falsy value for none. Modified in
      place:
        * "targets" is set to zeros of shape
          [batch_size, 1, 1, num_channels] if absent, then overwritten with
          the intermediate samples;
        * an "inputs" of rank < 4 is temporarily expanded on axis 2 and
          restored before returning.
    *args: unused.
    **kwargs: unused.

  Returns:
    Tensor of sampled ids with the same shape as the first-pass samples.
  """
  del args, kwargs  # Unused.
  # Inputs and features preparation needed to handle edge cases.
  if not features:
    features = {}
  inputs_old = None
  if "inputs" in features and len(features["inputs"].shape) < 4:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 2)

  # Snapshot the hparams overridden below. `missing` marks fields that did
  # not exist, so they are removed again rather than left behind.
  missing = object()
  saved_hparams = {
      name: getattr(self.hparams, name, missing)
      for name in ("sampling_temp", "autoregressive_dropout")
  }
  try:
    # Sample first.
    try:
      num_channels = self.hparams.problem.num_channels
    except AttributeError:
      num_channels = 1
    if "targets" not in features:
      features["targets"] = tf.zeros(
          [self.hparams.batch_size, 1, 1, num_channels], dtype=tf.int32)
    logits, _ = self(features)  # pylint: disable=not-callable
    samples = common_layers.sample_with_temperature(logits, 0.0)
    shape = common_layers.shape_list(samples)

    # Sample again if requested for the autoregressive part.
    extra_samples = self.hparams.autoregressive_decode_steps
    self.hparams.autoregressive_dropout = 0.2
    for i in range(extra_samples):
      # Anneal dropout and temperature over the final two passes.
      if i == extra_samples - 2:
        self.hparams.autoregressive_dropout -= 0.1
        self.hparams.sampling_temp /= 2
      if i == extra_samples - 1:
        self.hparams.autoregressive_dropout -= 0.1
        self.hparams.sampling_temp = 0.0
      features["targets"] = samples
      old_samples1d = tf.reshape(samples, [shape[0], -1, shape[3]])
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        logits, _ = self(features)  # pylint: disable=not-callable
      samples = common_layers.sample_with_temperature(
          logits, self.hparams.sampling_temp)
      samples1d = tf.reshape(samples, [shape[0], -1, shape[3]])
      # Keep positions < i from the previous pass; resample the rest.
      samples1d = tf.concat(
          [old_samples1d[:, :i, :], samples1d[:, i:, :]], axis=1)
      samples = tf.reshape(samples1d, shape)
    # Return samples.
    return samples
  finally:
    # Restore inputs to not confuse Estimator in edge cases.
    if inputs_old is not None:
      features["inputs"] = inputs_old
    # Undo the temporary hparams overrides so repeated calls behave the same.
    for name, value in saved_hparams.items():
      if value is not missing:
        setattr(self.hparams, name, value)
      elif hasattr(self.hparams, name):
        delattr(self.hparams, name)
def infer(self, features, *args, **kwargs):
  """Produce predictions from the model by sampling.

  The model is run once and its logits are sampled greedily (temperature
  0.0). Then `hparams.autoregressive_decode_steps` refinement passes follow.
  In refinement pass `i`:
    * the current samples are fed back as `features["targets"]`;
    * the model is re-run with variable reuse;
    * positions `i` onward along flattened axis 1 are replaced by fresh
      samples at `hparams.sampling_temp`, while earlier positions are kept.

  During refinement, `hparams.autoregressive_dropout` starts at 0.2 and is
  lowered by 0.1 on each of the last two passes. `hparams.sampling_temp` is
  halved on the second-to-last pass and set to 0.0 on the last. Both fields
  are restored before returning. Without this, a later call (e.g. when the
  inference graph is rebuilt) would start from a temperature already driven
  to 0.0 and sample greedily.

  Args:
    features: dict of input tensors, or a falsy value for none. Modified in
      place:
        * "targets" is set to zeros of shape
          [batch_size, 1, 1, num_channels] if absent, then overwritten with
          the intermediate samples;
        * an "inputs" of rank < 4 is temporarily expanded on axis 2 and
          restored before returning.
    *args: unused.
    **kwargs: unused.

  Returns:
    Tensor of sampled ids with the same shape as the first-pass samples.
  """
  del args, kwargs  # Unused.
  # Inputs and features preparation needed to handle edge cases.
  if not features:
    features = {}
  inputs_old = None
  if "inputs" in features and len(features["inputs"].shape) < 4:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 2)

  # Snapshot the hparams overridden below. `missing` marks fields that did
  # not exist, so they are removed again rather than left behind.
  missing = object()
  saved_hparams = {
      name: getattr(self.hparams, name, missing)
      for name in ("sampling_temp", "autoregressive_dropout")
  }
  try:
    # Sample first.
    try:
      num_channels = self.hparams.problem.num_channels
    except AttributeError:
      num_channels = 1
    if "targets" not in features:
      features["targets"] = tf.zeros(
          [self.hparams.batch_size, 1, 1, num_channels], dtype=tf.int32)
    logits, _ = self(features)  # pylint: disable=not-callable
    samples = common_layers.sample_with_temperature(logits, 0.0)
    shape = common_layers.shape_list(samples)

    # Sample again if requested for the autoregressive part.
    extra_samples = self.hparams.autoregressive_decode_steps
    self.hparams.autoregressive_dropout = 0.2
    for i in range(extra_samples):
      # Anneal dropout and temperature over the final two passes.
      if i == extra_samples - 2:
        self.hparams.autoregressive_dropout -= 0.1
        self.hparams.sampling_temp /= 2
      if i == extra_samples - 1:
        self.hparams.autoregressive_dropout -= 0.1
        self.hparams.sampling_temp = 0.0
      features["targets"] = samples
      old_samples1d = tf.reshape(samples, [shape[0], -1, shape[3]])
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        logits, _ = self(features)  # pylint: disable=not-callable
      samples = common_layers.sample_with_temperature(
          logits, self.hparams.sampling_temp)
      samples1d = tf.reshape(samples, [shape[0], -1, shape[3]])
      # Keep positions < i from the previous pass; resample the rest.
      samples1d = tf.concat(
          [old_samples1d[:, :i, :], samples1d[:, i:, :]], axis=1)
      samples = tf.reshape(samples1d, shape)
    # Return samples.
    return samples
  finally:
    # Restore inputs to not confuse Estimator in edge cases.
    if inputs_old is not None:
      features["inputs"] = inputs_old
    # Undo the temporary hparams overrides so repeated calls behave the same.
    for name, value in saved_hparams.items():
      if value is not missing:
        setattr(self.hparams, name, value)
      elif hasattr(self.hparams, name):
        delattr(self.hparams, name)
def __init__(
    self, batch_size, observation_space, action_space, policy_hparams,
    policy_dir, sampling_temp
):
  """Build the policy graph in a private tf.Graph and restore its weights.

  Args:
    batch_size: number of observations per step; leading placeholder dim.
    observation_space: its `shape` and `dtype` define the observation
      placeholder.
    action_space: passed to `rl.get_policy`.
    policy_hparams: hparams for `rl.get_policy`. `policy_network` names the
      variable scope whose variables are restored.
    policy_dir: checkpoint directory given to
      `trainer_lib.restore_checkpoint`.
    sampling_temp: action-sampling temperature. 0.0 means argmax.
  """
  super(PolicyAgent, self).__init__(
      batch_size, observation_space, action_space
  )
  self._sampling_temp = sampling_temp
  with tf.Graph().as_default():
    self._observations_t = tf.placeholder(
        shape=((batch_size,) + self.observation_space.shape),
        dtype=self.observation_space.dtype
    )
    (logits, self._values_t) = rl.get_policy(
        self._observations_t, policy_hparams, self.action_space
    )
    actions = common_layers.sample_with_temperature(logits, sampling_temp)
    if sampling_temp == 0.0:
      # Zero temperature means greedy (argmax) actions. Dividing the logits
      # by it would give inf/NaN probabilities, so report the matching
      # one-hot distribution on the argmax instead.
      self._probs_t = tf.one_hot(
          tf.argmax(logits, axis=-1),
          common_layers.shape_list(logits)[-1],
          dtype=logits.dtype)
    else:
      self._probs_t = tf.nn.softmax(logits / sampling_temp)
    self._actions_t = tf.cast(actions, tf.int32)
    model_saver = tf.train.Saver(
        tf.global_variables(policy_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
    )
    self._sess = tf.Session()
    self._sess.run(tf.global_variables_initializer())
    trainer_lib.restore_checkpoint(policy_dir, model_saver, self._sess)
def inner_loop(i, next_id, decoded_ids, cache):
  """Sample one more token per row and append it to the decoded ids.

  Args:
    i: current step index, forwarded to `symbols_to_logits_fn`.
    next_id: ids from the previous step, fed to the model.
    decoded_ids: ids decoded so far; the new ids are appended on axis 1.
    cache: decoding cache threaded through `symbols_to_logits_fn`.

  Returns:
    `(i + 1, new_ids, extended_decoded_ids, cache)`, where `new_ids` is the
    sampled ids expanded on axis 1.
  """
  step_logits, cache = symbols_to_logits_fn(next_id, i, cache)
  if hparams.sampling_method == "argmax":
    temp = 0.0
  else:
    temp = hparams.sampling_temp
  sampled = common_layers.sample_with_temperature(step_logits, temp)
  new_ids = tf.expand_dims(sampled, axis=1)
  return i + 1, new_ids, tf.concat([decoded_ids, new_ids], axis=1), cache
def inner_loop(i, finished, next_id, decoded_ids, cache):
  """Run one greedy decoding step and record which rows produced EOS.

  Args:
    i: current step index, forwarded to `symbols_to_logits_fn`.
    finished: boolean tensor, True for rows that have produced `eos_id`.
    next_id: ids from the previous step, fed to the model.
    decoded_ids: ids decoded so far; the new ids are appended on axis 1.
    cache: decoding cache threaded through `symbols_to_logits_fn`.

  Returns:
    `(i + 1, finished, new_ids, extended_decoded_ids, cache)`, where
    `new_ids` is the sampled ids expanded on axis 1.
  """
  step_logits, cache = symbols_to_logits_fn(next_id, i, cache)
  if hparams.sampling_method == "argmax":
    temp = 0.0
  else:
    temp = hparams.sampling_temp
  sampled = common_layers.sample_with_temperature(step_logits, temp)
  finished |= tf.equal(sampled, eos_id)
  new_ids = tf.expand_dims(sampled, axis=1)
  extended = tf.concat([decoded_ids, new_ids], axis=1)
  return i + 1, finished, new_ids, extended, cache
def inner_loop(cache_flag, i, finished, next_id, decoded_ids, cache):
  """Run one greedy decoding step and push the result into `sentence_cache`.

  The incoming `cache_flag` value is not read. It is replaced by the int64
  result of `sentence_cache.AddMultipleEntries`, which runs through
  `tf.py_func`.

  Args:
    cache_flag: previous flag value (ignored, replaced).
    i: current step index, forwarded to `symbols_to_logits_fn`.
    finished: boolean tensor, True for rows that have produced `eos_id`.
    next_id: ids from the previous step, fed to the model.
    decoded_ids: ids decoded so far; the new ids are appended on axis 1.
    cache: decoding cache threaded through `symbols_to_logits_fn`.

  Returns:
    `(new_flag, i + 1, finished, new_ids, extended_decoded_ids, cache)`.
  """
  step_logits, cache, step_out = symbols_to_logits_fn(next_id, i, cache)
  temp = hparams.sampling_temp if hparams.sampling_method != "argmax" else 0.0
  sampled = common_layers.sample_with_temperature(step_logits, temp)
  finished |= tf.equal(sampled, eos_id)
  new_ids = tf.expand_dims(sampled, axis=1)
  # Python-side side effect: hand the new ids and the extra model output to
  # sentence_cache.
  new_flag = tf.py_func(
      sentence_cache.AddMultipleEntries, [new_ids, step_out], tf.int64)
  # py_func output has unknown static shape; pin it to a scalar (presumably
  # to match the loop variable's shape invariant).
  new_flag.set_shape(tf.TensorShape([]))
  extended = tf.concat([decoded_ids, new_ids], axis=1)
  return new_flag, i + 1, finished, new_ids, extended, cache
def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
  """One step of greedy decoding.

  Args:
    i: current step index, forwarded to `symbols_to_logits_fn`.
    hit_eos: boolean tensor, True for rows that have produced `eos_id`.
    next_id: ids from the previous step, fed to the model.
    decoded_ids: ids decoded so far; the new ids are appended on axis 1.
    cache: decoding cache threaded through `symbols_to_logits_fn`.
    log_prob: running log-probability of each row's sampled tokens, up to and
      including its first EOS.

  Returns:
    `(i + 1, hit_eos, next_id, decoded_ids, cache, log_prob)`, with `next_id`
    expanded on axis 1 and appended to `decoded_ids`.
  """
  logits, cache = symbols_to_logits_fn(next_id, i, cache)
  log_probs = common_layers.log_prob_from_logits(logits)
  temperature = (0.0 if hparams.sampling_method == "argmax"
                 else hparams.sampling_temp)
  next_id = common_layers.sample_with_temperature(logits, temperature)

  log_prob_indices = tf.stack(
      [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
  # Ignore tokens sampled after a row has already produced EOS. Otherwise
  # finished rows keep accumulating log-probs while other rows still decode.
  log_prob += tf.gather_nd(
      log_probs, log_prob_indices) * (1 - tf.to_float(hit_eos))
  # Update hit_eos only after accumulating, so the first EOS token itself is
  # included in log_prob.
  hit_eos |= tf.equal(next_id, eos_id)

  next_id = tf.expand_dims(next_id, axis=1)
  decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
  return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob
def env_step(arg1, arg2, arg3):  # pylint: disable=unused-argument
  """Sample an action from the policy and advance the batched environment.

  The arguments are ignored. The policy reads `obs_copy` from the enclosing
  scope.

  Returns:
    `(pdf, value_function, done)`, each reshaped to `(num_agents,)`. `pdf` is
    the policy probability of the sampled action. The returned tensors are
    created under a control dependency on the environment's `reward` and
    `done`.
  """
  policy_logits, values = get_policy(
      obs_copy, ppo_hparams, batch_env.action_space
  )
  sampled = common_layers.sample_with_temperature(policy_logits, sampling_temp)
  chosen = tf.cast(sampled, tf.int32)
  reward, done = batch_env.simulate(chosen[:, 0, ...])

  probs = tfp.distributions.Categorical(logits=policy_logits).prob(chosen)
  flat = (num_agents,)
  probs = tf.reshape(probs, shape=flat)
  values = tf.reshape(values, shape=flat)
  done = tf.reshape(done, shape=flat)

  # Ensure the environment step has executed before the outputs are used.
  with tf.control_dependencies([reward, done]):
    outputs = (tf.identity(probs), tf.identity(values), tf.identity(done))
  return outputs
def pixels_from_softmax(frame_logits, pure_sampling=False,
                        temperature=1.0, gumbel_noise_factor=0.2):
  """Given frame_logits from a per-pixel softmax, generate colors.

  Args:
    frame_logits: logits whose last axis indexes the discrete pixel values
      (e.g. 256 levels for 8-bit frames). The number of levels is read from
      this axis, so the soft average uses values 0..num_levels-1, the same
      index range the pure-sampling branch returns.
    pure_sampling: if True, sample each pixel with
      `common_layers.sample_with_temperature`.
    temperature: softmax temperature. 0.0 also takes the
      `sample_with_temperature` path.
    gumbel_noise_factor: scale of the Gumbel noise added to the log-softmax.

  Returns:
    If `pure_sampling` or `temperature == 0.0`, the output of
    `sample_with_temperature`. Otherwise a float tensor of shape
    `frame_logits.shape[:-1]`: the Gumbel-perturbed softmax-weighted average
    pixel value, rounded on the forward pass with an identity gradient.
  """
  # If we're purely sampling, just sample each pixel.
  if pure_sampling or temperature == 0.0:
    return common_layers.sample_with_temperature(frame_logits, temperature)

  # Gumbel-sample from the pixel softmax and average by pixel values.
  num_levels = common_layers.shape_list(frame_logits)[-1]
  # Shape [num_levels]; broadcasting aligns it with the last axis of frame.
  pixel_range = tf.to_float(tf.range(num_levels))

  frame_logits = tf.nn.log_softmax(frame_logits)
  gumbel_samples = discretization.gumbel_sample(
      common_layers.shape_list(frame_logits)) * gumbel_noise_factor

  frame = tf.nn.softmax((frame_logits + gumbel_samples) / temperature, axis=-1)
  result = tf.reduce_sum(frame * pixel_range, axis=-1)
  # Round on the forward pass, not on the backward one.
  return result + tf.stop_gradient(tf.round(result) - result)
def inject_latent(self, layer, features, filters):
  """Inject a deterministic latent based on the target frame.

  Builds a latent of `hparams.bottleneck_bits` entries, binarized to +/-1
  (except where training mixes in the continuous value, see the end). It is
  added to `layer` through `add_d`.

  Prediction mode (`self.is_predicting`): the latent is not computed from the
  target.
    * With `hparams.full_latent_tower`, it is thresholded uniform noise.
    * Otherwise, it is sampled byte by byte from an LSTM conditioned on
      `layer`.

  Other modes: the latent is encoded from `features["cur_target_frame"]` and
  `features["inputs"]`. When `full_latent_tower` is False, the same LSTM is
  also trained (teacher-forced) to predict the latent's bytes, and that
  cross-entropy is returned as `pred_loss`.

  Args:
    layer: tensor to condition. Its last dimension sets the number of
      channels `add_d` produces. NOTE(review): the reshape to
      `[batch_size, prod(layer_shape[1:])]` implies batch is axis 0.
    features: dict providing "cur_target_frame" and "inputs" (concatenated on
      the last axis). Only read outside prediction mode.
    filters: ignored; deleted immediately and replaced by
      `hparams.hidden_size`.

  Returns:
    `(layer_with_latent, pred_loss)`. `pred_loss` is 0.0 in prediction mode
    and whenever `full_latent_tower` is set.

  NOTE(review): the LSTM loops run `bottleneck_bits // 8` steps of one byte
  each. This presumably requires `bottleneck_bits` to be a multiple of 8.
  """
  del filters
  hparams = self.hparams
  final_filters = common_layers.shape_list(layer)[-1]
  filters = hparams.hidden_size
  kernel = (4, 4)
  layer_shape = common_layers.shape_list(layer)
  batch_size = layer_shape[0]
  state_size = hparams.latent_predictor_state_size
  # Layer objects created once, so every step of the LSTM loops below reuses
  # the same cell, byte classifier (256 classes) and byte embedding.
  lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
  discrete_predict = tf.layers.Dense(256, name="discrete_predict")
  discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")

  def add_d(layer, d):
    """Project latent `d` to `final_filters` channels and merge into layer."""
    z_mul = tf.layers.dense(d, final_filters, name="unbottleneck_mul")
    if not hparams.complex_addn:
      # Simple mode: purely additive conditioning.
      return layer + z_mul
    # complex_addn: sigmoid-gated multiply, then an extra additive term.
    layer *= tf.nn.sigmoid(z_mul)
    z_add = tf.layers.dense(d, final_filters, name="unbottleneck_add")
    layer += z_add
    return layer

  if self.is_predicting:
    if hparams.full_latent_tower:
      # No learned prior: uniform noise, thresholded to +/-1 below.
      rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
    else:
      # Autoregressively sample the latent one byte at a time. The LSTM's
      # initial input and state are projections of the flattened layer.
      layer_pred = tf.reshape(layer, [batch_size, prod(layer_shape[1:])])
      prediction = tf.layers.dense(layer_pred, state_size, name="istate")
      c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
      m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
      state = (c_state, m_state)
      outputs = []
      for i in range(hparams.bottleneck_bits // 8):
        output, state = lstm_cell(prediction, state)
        discrete_logits = discrete_predict(output)
        discrete_samples = common_layers.sample_with_temperature(
            discrete_logits, hparams.latent_predictor_temperature)
        outputs.append(tf.expand_dims(discrete_samples, axis=1))
        # Feed the embedding of the sampled byte into the next step.
        prediction = discrete_embed(tf.one_hot(discrete_samples, 256))
      outputs = tf.concat(outputs, axis=1)
      # Unpack each byte into 8 bits.
      outputs = discretization.int_to_bit(outputs, 8)
      rand = tf.reshape(outputs, [batch_size, 1, 1, hparams.bottleneck_bits])
    # Threshold at 0.5 and map to +/-1.
    d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    return add_d(layer, d), 0.0

  # Embed.
  frames = tf.concat([features["cur_target_frame"], features["inputs"]],
                     axis=-1)
  x = tf.layers.dense(
      frames, filters, name="latent_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  if hparams.full_latent_tower:
    # Convolutional encoder: each step halves the spatial size (strides 2)
    # and, for the first `filter_double_steps` steps, doubles the filters.
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("latent_downstride%d" % i):
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tf.layers.conv2d(x, filters, kernel,
                             activation=common_layers.belu,
                             strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)
  else:
    x = common_layers.double_discriminator(x)
    # Add two singleton axes so the latent can broadcast against layer.
    x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
  x = tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck")
  x0 = tf.tanh(x)
  # Straight-through binarization: the forward value is +1 where x0 > 0 and
  # -1 otherwise; the gradient flows through x0 unchanged.
  d = x0 + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x0)) - 1.0 - x0)
  pred_loss = 0.0
  if not hparams.full_latent_tower:
    # Train the latent predictor (teacher forcing). Map the +/-1 bits to
    # {0, 1} and pack each group of 8 into a byte id.
    d_pred = tf.reshape(tf.maximum(tf.stop_gradient(d), 0), [
        batch_size, hparams.bottleneck_bits // 8, 8])
    d_int = discretization.bit_to_int(d_pred, 8)
    tf.summary.histogram("d_int", tf.reshape(d_int, [-1]))
    d_hot = tf.one_hot(d_int, 256, axis=-1)
    d_pred = discrete_embed(d_hot)
    layer_pred = tf.reshape(layer, [batch_size, prod(layer_shape[1:])])
    prediction0 = tf.layers.dense(layer_pred, state_size, name="istate")
    c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
    m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
    # Step 0 sees the layer projection; step i > 0 sees the embedding of the
    # true byte i-1. Output i is scored against byte i.
    pred = tf.concat([tf.expand_dims(prediction0, axis=1), d_pred], axis=1)
    state = (c_state, m_state)
    outputs = []
    for i in range(hparams.bottleneck_bits // 8):
      output, state = lstm_cell(pred[:, i, :], state)
      outputs.append(tf.expand_dims(output, axis=1))
    outputs = tf.concat(outputs, axis=1)
    d_int_pred = discrete_predict(outputs)
    pred_loss = tf.losses.sparse_softmax_cross_entropy(
        logits=d_int_pred, labels=d_int)
    pred_loss = tf.reduce_mean(pred_loss)
  if hparams.mode == tf.estimator.ModeKeys.TRAIN:
    # Regularize the bottleneck: add Gaussian noise before the tanh.
    x += tf.truncated_normal(
        common_layers.shape_list(x), mean=0.0, stddev=0.2)
    x = tf.tanh(x)
    # Flip the sign of each entry with probability hparams.bottleneck_noise.
    noise = tf.random_uniform(common_layers.shape_list(x))
    noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise, noise)) - 1.0
    x *= noise
    d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
    # With probability p, keep the binarized latent d; otherwise keep the
    # continuous x. The condition is a [batch_size] vector, so the choice is
    # presumably made per example. p comes from inverse_lin_decay over
    # discrete_warmup_steps (presumably ramping up during warmup — confirm).
    p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
    d = tf.where(tf.less(tf.random_uniform([batch_size]), p), d, x)
  return add_d(layer, d), pred_loss