def reward_prediction_big(
    self, input_images, input_reward, action, latent, mid_outputs):
  """Builds a reward prediction network."""
  del mid_outputs
  conv_size = self.tinyify([32, 32, 16, 8])

  with tf.variable_scope("reward_pred", reuse=tf.AUTO_REUSE):
    x = tf.concat(input_images, axis=3)
    x = tfcl.layer_norm(x)

    if not self.hparams.small_mode:
      x = tfl.conv2d(x, conv_size[1], [3, 3], strides=(2, 2),
                     activation=tf.nn.relu, name="reward_conv1")
      x = tfcl.layer_norm(x)

    # Inject additional inputs.
    if action is not None:
      x = common_video.inject_additional_input(
          x, action, "action_enc", self.hparams.action_injection)
    if input_reward is not None:
      x = common_video.inject_additional_input(x, input_reward, "reward_enc")
    if latent is not None:
      latent = tfl.flatten(latent)
      latent = tf.expand_dims(latent, axis=1)
      latent = tf.expand_dims(latent, axis=1)
      x = common_video.inject_additional_input(x, latent, "latent_enc")

    x = tfl.conv2d(x, conv_size[2], [3, 3], strides=(2, 2),
                   activation=tf.nn.relu, name="reward_conv2")
    x = tfcl.layer_norm(x)
    x = tfl.conv2d(x, conv_size[3], [3, 3], strides=(2, 2),
                   activation=tf.nn.relu, name="reward_conv3")
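# --- Illustrative sketch (not part of the model code) ---
# A minimal NumPy sketch of the broadcast-and-concatenate conditioning pattern
# used throughout this file (tile_and_concat / inject_additional_input): a
# per-example vector is tiled over the spatial grid and joined channel-wise
# with a feature map. The real helpers may instead add or multiply a projected
# vector depending on hparams.action_injection; this only shows the shapes.
import numpy as np

def tile_and_concat_sketch(feature_map, vector):
  """feature_map: [B, H, W, C]; vector: [B, D] -> [B, H, W, C + D]."""
  _, h, w, _ = feature_map.shape
  tiled = np.tile(vector[:, None, None, :], (1, h, w, 1))
  return np.concatenate([feature_map, tiled], axis=-1)

# Hypothetical sizes: an 8-dim action embedding and a 16-channel feature map.
features = np.random.randn(2, 13, 10, 16)
action_embedding = np.random.randn(2, 8)
assert tile_and_concat_sketch(features, action_embedding).shape == (2, 13, 10, 24)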
def inject_latent(self, layer, inputs, target, action):
  """Inject a deterministic latent based on the target frame."""
  hparams = self.hparams
  final_filters = common_layers.shape_list(layer)[-1]
  filters = hparams.hidden_size
  kernel = (4, 4)
  layer_shape = common_layers.shape_list(layer)
  activation_fn = common_layers.belu
  if hparams.activation_fn == "relu":
    activation_fn = tf.nn.relu

  def add_bits(layer, bits):
    z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
    if not hparams.complex_addn:
      return layer + z_mul
    layer *= tf.nn.sigmoid(z_mul)
    z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
    layer += z_add
    return layer

  if not self.is_training:
    if hparams.full_latent_tower:
      rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
      bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    else:
      bits, _ = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
          temperature=hparams.latent_predictor_temperature)
      bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
    return add_bits(layer, bits), 0.0

  # Embed.
  frames = tf.concat(inputs + [target], axis=-1)
  x = tfl.dense(
      frames, filters, name="latent_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Add embedded action if present.
  if action is not None:
    x = common_video.inject_additional_input(
        x, action, "action_enc_latent", hparams.action_injection)

  if hparams.full_latent_tower:
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("latent_downstride%d" % i):
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tfl.conv2d(x, filters, kernel,
                       activation=activation_fn,
                       strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)
  else:
    x = common_layers.double_discriminator(x)
    x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

  bits, bits_clean = discretization.tanh_discrete_bottleneck(
      x, hparams.bottleneck_bits, hparams.bottleneck_noise,
      hparams.discretize_warmup_steps, hparams.mode)
  if not hparams.full_latent_tower:
    _, pred_loss = discretization.predict_bits_with_lstm(
        layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
        target_bits=bits_clean)
    # Mix bits from the latent with predicted bits on the forward pass as noise.
    if hparams.latent_rnn_max_sampling > 0.0:
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        bits_pred, _ = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size,
            hparams.bottleneck_bits,
            temperature=hparams.latent_predictor_temperature)
        bits_pred = tf.expand_dims(tf.expand_dims(bits_pred, axis=1), axis=2)
      # Be bits_pred on the forward pass but bits on the backward one.
      bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
      # Select which bits to take from pred sampling with bit_p probability.
      which_bit = tf.random_uniform(common_layers.shape_list(bits))
      bit_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps)
      bit_p *= hparams.latent_rnn_max_sampling
      bits = tf.where(which_bit < bit_p, bits_pred, bits)

  res = add_bits(layer, bits)
  # During training, sometimes skip the latent to help action-conditioning.
  res_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps / 2)
  res_p *= hparams.latent_use_max_probability
  res_rand = tf.random_uniform([layer_shape[0]])
  res = tf.where(res_rand < res_p, res, layer)
  return res, pred_loss
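# --- Illustrative sketch (not part of the model code) ---
# The line `bits_clean + tf.stop_gradient(bits_pred - bits_clean)` above is a
# straight-through substitution: the forward value equals bits_pred, while the
# upstream gradient is routed to bits_clean unchanged. A standalone TF 2 eager
# example (the model code itself runs in TF 1 graph mode):
import tensorflow as tf

bits_clean = tf.Variable([0.2, -0.7, 0.9])
bits_pred = tf.constant([1.0, -1.0, 1.0])
with tf.GradientTape() as tape:
  mixed = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
  loss = tf.reduce_sum(mixed ** 2)
print(mixed.numpy())                    # forward value equals bits_pred
print(tape.gradient(loss, bits_clean))  # d(loss)/d(mixed) flows to bits_clean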
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_extra):
  del rewards, video_extra

  hparams = self.hparams
  filters = hparams.hidden_size
  kernel2 = (4, 4)

  # Embed the inputs.
  stacked_frames = tf.concat(frames, axis=-1)
  inputs_shape = common_layers.shape_list(stacked_frames)
  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_frames, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action if present.
  if self.has_actions:
    action = actions[-1]
    x = common_video.inject_additional_input(
        x, action, "action_enc", hparams.action_injection)

  # Inject latent if present. Only for stochastic models.
  x, extra_loss = self.inject_latent(x, frames, target_frame)

  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x, internal_states = self.middle_network(x, internal_states)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      if self.has_actions:
        x = common_video.inject_additional_input(
            x, action, "action_enc", hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=common_layers.belu,
          strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  if self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  # No reward prediction if not needed.
  if not self.has_rewards:
    return x, None, extra_loss, internal_states

  # Reward prediction based on middle and final logits.
  reward_pred = tf.concat([x_mid, x_fin], axis=-1)
  reward_pred = tf.nn.relu(
      tf.layers.dense(reward_pred, 128, name="reward_pred"))
  reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims.
  reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims.

  return x, reward_pred, extra_loss, internal_states
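# --- Illustrative sketch (not part of the model code) ---
# The channel schedule implied by the two loops above: filters double during
# the first `filter_double_steps` down-strides and are halved again during the
# last `filter_double_steps` up-strides, returning to hidden_size at the end.
# The hparam values below are hypothetical.
def channel_schedule(hidden_size, num_compress_steps, filter_double_steps):
  filters = hidden_size
  down, up = [], []
  for i in range(num_compress_steps):
    if i < filter_double_steps:
      filters *= 2
    down.append(filters)
  for i in range(num_compress_steps):
    if i >= num_compress_steps - filter_double_steps:
      filters //= 2
    up.append(filters)
  return down, up

# hidden_size=64, num_compress_steps=6, filter_double_steps=3
# -> down: [128, 256, 512, 512, 512, 512], up: [512, 512, 512, 256, 128, 64]
print(channel_schedule(64, 6, 3))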
def bottom_part_tower(self, input_image, input_reward, action, latent,
                      lstm_state, lstm_size, conv_size, concat_latent=False):
  """The bottom part of predictive towers.

  With the current (early) design, the main prediction tower and the reward
  prediction tower share the same architecture. TF Scope can be adjusted as
  required to either share or not share the weights between the two towers.

  Args:
    input_image: the current image.
    input_reward: the current reward.
    action: the action taken by the agent.
    latent: the latent vector.
    lstm_state: the current internal states of conv lstms.
    lstm_size: the size of lstms.
    conv_size: the size of convolutions.
    concat_latent: whether or not to concatenate the latent at every step.

  Returns:
    - the output of the partial network.
    - intermediate outputs for skip connections.
  """
  lstm_func = common_video.conv_lstm_2d
  tile_and_concat = common_video.tile_and_concat

  input_image = common_layers.make_even_size(input_image)
  concat_input_image = tile_and_concat(
      input_image, latent, concat_latent=concat_latent)

  layer_id = 0
  enc0 = tfl.conv2d(
      concat_input_image,
      conv_size[0], [5, 5],
      strides=(2, 2),
      activation=tf.nn.relu,
      padding="SAME",
      name="scale1_conv1")
  enc0 = tfcl.layer_norm(enc0, scope="layer_norm1")

  hidden1, lstm_state[layer_id] = lstm_func(
      enc0, lstm_state[layer_id], lstm_size[layer_id], name="state1")
  hidden1 = tile_and_concat(hidden1, latent, concat_latent=concat_latent)
  hidden1 = tfcl.layer_norm(hidden1, scope="layer_norm2")
  layer_id += 1

  hidden2, lstm_state[layer_id] = lstm_func(
      hidden1, lstm_state[layer_id], lstm_size[layer_id], name="state2")
  hidden2 = tfcl.layer_norm(hidden2, scope="layer_norm3")
  hidden2 = common_layers.make_even_size(hidden2)
  enc1 = tfl.conv2d(hidden2, hidden2.get_shape()[3], [3, 3], strides=(2, 2),
                    padding="SAME", activation=tf.nn.relu, name="conv2")
  enc1 = tile_and_concat(enc1, latent, concat_latent=concat_latent)
  layer_id += 1

  if self.hparams.small_mode:
    hidden4, enc2 = hidden2, enc1
  else:
    hidden3, lstm_state[layer_id] = lstm_func(
        enc1, lstm_state[layer_id], lstm_size[layer_id], name="state3")
    hidden3 = tile_and_concat(hidden3, latent, concat_latent=concat_latent)
    hidden3 = tfcl.layer_norm(hidden3, scope="layer_norm4")
    layer_id += 1

    hidden4, lstm_state[layer_id] = lstm_func(
        hidden3, lstm_state[layer_id], lstm_size[layer_id], name="state4")
    hidden4 = tile_and_concat(hidden4, latent, concat_latent=concat_latent)
    hidden4 = tfcl.layer_norm(hidden4, scope="layer_norm5")
    hidden4 = common_layers.make_even_size(hidden4)
    enc2 = tfl.conv2d(hidden4, hidden4.get_shape()[3], [3, 3], strides=(2, 2),
                      padding="SAME", activation=tf.nn.relu, name="conv3")
    layer_id += 1

  if action is not None:
    enc2 = common_video.inject_additional_input(
        enc2, action, "action_enc", self.hparams.action_injection)
  if input_reward is not None:
    enc2 = common_video.inject_additional_input(
        enc2, input_reward, "reward_enc")

  if latent is not None and not concat_latent:
    with tf.control_dependencies([latent]):
      enc2 = tf.concat([enc2, latent], axis=3)

  enc3 = tfl.conv2d(enc2, hidden4.get_shape()[3], [1, 1], strides=(1, 1),
                    padding="SAME", activation=tf.nn.relu, name="conv4")

  hidden5, lstm_state[layer_id] = lstm_func(
      enc3, lstm_state[layer_id], lstm_size[layer_id], name="state5")
  hidden5 = tfcl.layer_norm(hidden5, scope="layer_norm6")
  hidden5 = tile_and_concat(hidden5, latent, concat_latent=concat_latent)
  layer_id += 1
  return hidden5, (enc0, enc1), layer_id
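# --- Illustrative sketch (not part of the model code) ---
# Spatial sizes through the tower above, assuming make_even_size pads height
# and width up to the next even number and each stride-2 conv then halves
# them. The 105x80 input resolution is a hypothetical Atari-like example.
def even(n):
  return n + (n % 2)

def downsample2(n):
  return even(n) // 2

h, w = 105, 80
h, w = downsample2(h), downsample2(w)  # after scale1_conv1: 53 x 40
print(h, w)
h, w = downsample2(h), downsample2(w)  # after conv2: 27 x 20
print(h, w)
h, w = downsample2(h), downsample2(w)  # after conv3: 14 x 10
print(h, w)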
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_extra):
  del rewards, video_extra

  hparams = self.hparams
  filters = hparams.hidden_size
  kernel2 = (4, 4)
  action = actions[-1]

  # Stack the inputs.
  if internal_states is not None and hparams.concat_internal_states:
    # Use the first part of the first internal state if asked to concatenate.
    batch_size = common_layers.shape_list(frames[0])[0]
    internal_state = internal_states[0][0][:batch_size, :, :, :]
    stacked_frames = tf.concat(frames + [internal_state], axis=-1)
  else:
    stacked_frames = tf.concat(frames, axis=-1)
  inputs_shape = common_layers.shape_list(stacked_frames)

  # Update internal states early if requested.
  if hparams.concat_internal_states:
    internal_states = self.update_internal_states_early(
        internal_states, frames)

  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_frames, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action if present.
  if self.has_actions:
    x = common_video.inject_additional_input(
        x, action, "action_enc", hparams.action_injection)

  # Inject latent if present. Only for stochastic models.
  x, extra_loss = self.inject_latent(x, frames, target_frame, action)

  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x, internal_states = self.middle_network(x, internal_states)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      if self.has_actions:
        x = common_video.inject_additional_input(
            x, action, "action_enc", hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=common_layers.belu,
          strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  if self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  # No reward prediction if not needed.
  if not self.has_rewards:
    return x, None, extra_loss, internal_states

  # Reward prediction based on middle and final logits.
  reward_pred = tf.concat([x_mid, x_fin], axis=-1)
  reward_pred = tf.nn.relu(tf.layers.dense(
      reward_pred, 128, name="reward_pred"))
  reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims.
  reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims.

  return x, reward_pred, extra_loss, internal_states
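# --- Illustrative sketch (not part of the model code) ---
# Shape walk-through of the reward head above: the pooled encoder features
# (x_mid) and pooled decoder features (x_fin) are [B, 1, 1, C] tensors; they
# are concatenated, passed through a dense+relu, and the two singleton spatial
# dims are squeezed away. NumPy stand-in with random weights, shapes only.
import numpy as np

batch, c_mid, c_fin, hidden = 4, 512, 256, 128  # hypothetical sizes
x_mid = np.random.randn(batch, 1, 1, c_mid)
x_fin = np.random.randn(batch, 1, 1, c_fin)
w = np.random.randn(c_mid + c_fin, hidden)
reward = np.maximum(np.concatenate([x_mid, x_fin], axis=-1) @ w, 0.0)
reward = np.squeeze(np.squeeze(reward, axis=1), axis=1)
assert reward.shape == (batch, hidden)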
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_extra):
  del rewards, video_extra

  hparams = self.hparams
  filters = hparams.hidden_size
  kernel2 = (4, 4)
  action = actions[-1]
  activation_fn = common_layers.belu
  if self.hparams.activation_fn == "relu":
    activation_fn = tf.nn.relu

  # Normalize frames.
  frames = [common_layers.standardize_images(f) for f in frames]

  # Stack the inputs.
  if internal_states is not None and hparams.concat_internal_states:
    # Use the first part of the first internal state if asked to concatenate.
    batch_size = common_layers.shape_list(frames[0])[0]
    internal_state = internal_states[0][0][:batch_size, :, :, :]
    stacked_frames = tf.concat(frames + [internal_state], axis=-1)
  else:
    stacked_frames = tf.concat(frames, axis=-1)
  inputs_shape = common_layers.shape_list(stacked_frames)

  # Update internal states early if requested.
  if hparams.concat_internal_states:
    internal_states = self.update_internal_states_early(
        internal_states, frames)

  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_frames, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=activation_fn,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  if self.has_actions:
    with tf.variable_scope("policy"):
      x_flat = tf.layers.flatten(x)
      policy_pred = tf.layers.dense(x_flat, self.hparams.problem.num_actions)
      value_pred = tf.layers.dense(x_flat, 1)
      value_pred = tf.squeeze(value_pred, axis=-1)
  else:
    policy_pred, value_pred = None, None

  # Add embedded action if present.
  if self.has_actions:
    x = common_video.inject_additional_input(
        x, action, "action_enc", hparams.action_injection)

  # Inject latent if present. Only for stochastic models.
  norm_target_frame = common_layers.standardize_images(target_frame)
  x, extra_loss = self.inject_latent(x, frames, norm_target_frame, action)

  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x, internal_states = self.middle_network(x, internal_states)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      if self.has_actions:
        x = common_video.inject_additional_input(
            x, action, "action_enc", hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=activation_fn,
          strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  if hparams.do_autoregressive_rnn:
    # If enabled, we predict the target frame autoregressively using RNNs.
    # To this end, the current prediction is flattened into one long sequence
    # of sub-pixels, and so is the target frame. Each sub-pixel (RGB value,
    # from 0 to 255) is predicted with an RNN. To avoid doing as many steps
    # as width * height * channels, we only use a number of pixels back,
    # as many as hparams.autoregressive_rnn_lookback.
    with tf.variable_scope("autoregressive_rnn"):
      batch_size = common_layers.shape_list(frames[0])[0]
      # Height, width, channels and lookback are the constants we need.
      h, w = inputs_shape[1], inputs_shape[2]  # 105, 80 on Atari games.
      c = hparams.problem.num_channels
      lookback = hparams.autoregressive_rnn_lookback
      assert (h * w) % lookback == 0, "Lookback must divide the number of pixels."
      m = (h * w) // lookback  # Batch size multiplier for the RNN.
      # These are logits that will be used as inputs to the RNN.
      rnn_inputs = tf.layers.dense(x, c * 64, name="rnn_inputs")
      # They are of shape [batch_size, h, w, c, 64], reshaping now.
      rnn_inputs = tf.reshape(rnn_inputs, [batch_size * m, lookback * c, 64])
      # Same for the target frame.
      rnn_target = tf.reshape(target_frame, [batch_size * m, lookback * c])
      # Construct rnn starting state: flatten rnn_inputs, apply a relu layer.
      rnn_start_state = tf.nn.relu(tf.layers.dense(
          tf.nn.relu(tf.layers.flatten(rnn_inputs)), 256,
          name="rnn_start_state"))
      # Our RNN function API is on bits, each subpixel has 8 bits.
      total_num_bits = lookback * c * 8
      # We need to provide RNN targets as bits (due to the API).
      rnn_target_bits = discretization.int_to_bit(rnn_target, 8)
      rnn_target_bits = tf.reshape(rnn_target_bits,
                                   [batch_size * m, total_num_bits])
      if self.is_training:
        # Run the RNN in training mode, add its loss to the losses.
        rnn_predict, rnn_loss = discretization.predict_bits_with_lstm(
            rnn_start_state, 128, total_num_bits,
            target_bits=rnn_target_bits, extra_inputs=rnn_inputs)
        extra_loss += rnn_loss
        # We still use non-RNN predictions too in order to guide the network.
        x = tf.layers.dense(x, c * 256, name="logits")
        x = tf.reshape(x, [batch_size, h, w, c, 256])
        rnn_predict = tf.reshape(rnn_predict, [batch_size, h, w, c, 256])
        # Mix non-RNN and RNN predictions so that after warmup the RNN is 90%.
        x = tf.reshape(tf.nn.log_softmax(x), [batch_size, h, w, c * 256])
        rnn_predict = tf.nn.log_softmax(rnn_predict)
        rnn_predict = tf.reshape(rnn_predict, [batch_size, h, w, c * 256])
        alpha = 0.9 * common_layers.inverse_lin_decay(
            hparams.autoregressive_rnn_warmup_steps)
        x = alpha * rnn_predict + (1.0 - alpha) * x
      else:
        # In prediction mode, run the RNN without any targets.
        bits, _ = discretization.predict_bits_with_lstm(
            rnn_start_state, 128, total_num_bits, extra_inputs=rnn_inputs,
            temperature=0.0)  # No sampling from this RNN, just greedy.
        # The output is in bits, get back the predicted pixels.
        bits = tf.reshape(bits, [batch_size * m, lookback * c, 8])
        ints = discretization.bit_to_int(tf.maximum(bits, 0), 8)
        ints = tf.reshape(ints, [batch_size, h, w, c])
        x = tf.reshape(tf.one_hot(ints, 256), [batch_size, h, w, c * 256])
  elif self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  reward_pred = None
  if self.has_rewards:
    # Reward prediction based on middle and final logits.
    reward_pred = tf.concat([x_mid, x_fin], axis=-1)
    reward_pred = tf.nn.relu(
        tf.layers.dense(reward_pred, 128, name="reward_pred"))
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims.
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims.

  return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
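# --- Illustrative sketch (not part of the model code) ---
# The autoregressive branch above encodes every sub-pixel (0..255) as 8 bits
# and decodes it back after the RNN. A NumPy round-trip sketch of that 8-bit
# encoding; the actual discretization.int_to_bit / bit_to_int helpers may use
# a different bit order or value range (e.g. {-1, 1} instead of {0, 1}).
import numpy as np

def int_to_bits(x, num_bits=8):
  return (x[..., None] >> np.arange(num_bits)[::-1]) & 1  # MSB first

def bits_to_int(bits, num_bits=8):
  return (bits * (1 << np.arange(num_bits)[::-1])).sum(axis=-1)

pixels = np.array([[0, 17, 255], [128, 64, 3]])
assert np.array_equal(bits_to_int(int_to_bits(pixels)), pixels)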
def inject_latent(self, layer, inputs, target, action):
  """Inject a deterministic latent based on the target frame."""
  hparams = self.hparams
  final_filters = common_layers.shape_list(layer)[-1]
  filters = hparams.hidden_size
  kernel = (4, 4)
  layer_shape = common_layers.shape_list(layer)

  def add_bits(layer, bits):
    z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
    if not hparams.complex_addn:
      return layer + z_mul
    layer *= tf.nn.sigmoid(z_mul)
    z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
    layer += z_add
    return layer

  if not self.is_training:
    if hparams.full_latent_tower:
      rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
      bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    else:
      bits, _ = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
          temperature=hparams.latent_predictor_temperature)
      bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
    return add_bits(layer, bits), 0.0

  # Embed.
  frames = tf.concat(inputs + [target], axis=-1)
  x = tfl.dense(
      frames, filters, name="latent_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Add embedded action if present.
  if action is not None:
    x = common_video.inject_additional_input(
        x, action, "action_enc_latent", hparams.action_injection)

  if hparams.full_latent_tower:
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("latent_downstride%d" % i):
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tfl.conv2d(x, filters, kernel,
                       activation=common_layers.belu,
                       strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)
  else:
    x = common_layers.double_discriminator(x)
    x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

  bits, bits_clean = discretization.tanh_discrete_bottleneck(
      x, hparams.bottleneck_bits, hparams.bottleneck_noise,
      hparams.discretize_warmup_steps, hparams.mode)
  if not hparams.full_latent_tower:
    _, pred_loss = discretization.predict_bits_with_lstm(
        layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
        target_bits=bits_clean)
    # Mix bits from the latent with predicted bits on the forward pass as noise.
    if hparams.latent_rnn_max_sampling > 0.0:
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        bits_pred, _ = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size,
            hparams.bottleneck_bits,
            temperature=hparams.latent_predictor_temperature)
        bits_pred = tf.expand_dims(tf.expand_dims(bits_pred, axis=1), axis=2)
      # Be bits_pred on the forward pass but bits on the backward one.
      bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
      # Select which bits to take from pred sampling with bit_p probability.
      which_bit = tf.random_uniform(common_layers.shape_list(bits))
      bit_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps)
      bit_p *= hparams.latent_rnn_max_sampling
      bits = tf.where(which_bit < bit_p, bits_pred, bits)

  res = add_bits(layer, bits)
  # During training, sometimes skip the latent to help action-conditioning.
  res_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps / 2)
  res_p *= hparams.latent_use_max_probability
  res_rand = tf.random_uniform([layer_shape[0]])
  res = tf.where(res_rand < res_p, res, layer)
  return res, pred_loss
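# --- Illustrative sketch (not part of the model code) ---
# The warm-up schedules used above, assuming common_layers.inverse_lin_decay
# ramps roughly linearly from a small value to 1.0 over `max_step` steps (the
# exact form is the library's; this is an approximation). `bit_p` is the
# chance of replacing a bottleneck bit with a sampled one, `res_p` the chance
# of keeping the latent-conditioned features at all. Hparam values below are
# hypothetical.
def inverse_lin_decay_sketch(max_step, step, min_value=0.01):
  return min_value + (1.0 - min_value) * min(float(step) / max_step, 1.0)

warmup_steps, max_sampling, use_max_p = 30000, 0.5, 0.8
for step in (0, 15000, 30000, 60000):
  bit_p = inverse_lin_decay_sketch(warmup_steps, step) * max_sampling
  res_p = inverse_lin_decay_sketch(warmup_steps // 2, step) * use_max_p
  print(step, round(bit_p, 3), round(res_p, 3))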