def inject_latent(self, layer, features, filters):
  """Inject a deterministic latent based on the target frame."""
  del filters
  hparams = self.hparams
  final_filters = common_layers.shape_list(layer)[-1]
  filters = hparams.hidden_size
  kernel = (4, 4)

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    layer_shape = common_layers.shape_list(layer)
    if hparams.full_latent_tower:
      rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
    else:
      rand = tf.random_uniform(
          layer_shape[:-3] + [1, 1, hparams.bottleneck_bits])
    d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    z = tf.layers.dense(d, final_filters, name="unbottleneck")
    return layer + z, 0.0

  # Embed.
  frames = tf.concat([features["cur_target_frame"], features["inputs"]],
                     axis=-1)
  x = tf.layers.dense(
      frames, filters, name="latent_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)
  if hparams.full_latent_tower:
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("latent_downstride%d" % i):
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tf.layers.conv2d(x, filters, kernel,
                             activation=common_layers.belu,
                             strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)
  else:
    x = common_layers.double_discriminator(x)
    x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
  x = tf.tanh(
      tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck"))
  d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
  if hparams.mode == tf.estimator.ModeKeys.TRAIN:
    noise = tf.random_uniform(common_layers.shape_list(x))
    noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise, noise)) - 1.0
    d *= noise
  z = tf.layers.dense(d, final_filters, name="unbottleneck")
  return layer + z, 0.0
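# --- Illustration (not part of the original code) ---
# A minimal sketch of the straight-through discretization used in
# inject_latent above: the forward value of `d` is the sign of `x` in
# {-1, +1}, but the gradient of `d` w.r.t. `x` is the identity, because the
# sign correction term is wrapped in tf.stop_gradient. Assumes TF 1.x.
import tensorflow as tf

x = tf.constant([-0.7, 0.2, 0.9])
d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
grad = tf.gradients(tf.reduce_sum(d), x)[0]
with tf.Session() as sess:
  print(sess.run(d))     # [-1.  1.  1.]
  print(sess.run(grad))  # [1. 1. 1.]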
def add_pos_signals(x, hparams, name="pos_emb"):
  with tf.variable_scope(name, reuse=False):
    if hparams.pos == "timing":
      x = common_attention.add_timing_signal_nd(x)
    else:
      assert hparams.pos == "emb"
      x = common_attention.add_positional_embedding_nd(
          x, hparams.max_length, name=name)
  return x
def add_pos_signals(x, hparams, name="pos_emb"):
  with tf.variable_scope(name, reuse=False):
    if hparams.pos == "timing":
      x = common_attention.add_timing_signal_nd(x)
    else:
      assert hparams.pos == "emb"
      x = common_attention.add_positional_embedding_nd(
          x, hparams.max_length, name)
  return x
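# --- Usage sketch (illustrative, not from the original project) ---
# Calling the helper above under TF 1.x. The HParams values here are
# assumptions chosen for the example, not taken from any real config.
import tensorflow as tf

hparams = tf.contrib.training.HParams(pos="timing", max_length=256)
x = tf.zeros([8, 32, 512])       # [batch, length, hidden_size]
x = add_pos_signals(x, hparams)  # adds sinusoidal timing signals; pos="emb"
                                 # would use learned positional embeddings.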
def decoder(self, x, encoder_layers=None):
  with tf.variable_scope("decoder"):
    hparams = self.hparams
    is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
    kernel, strides = self._get_kernel_and_strides()
    residual_kernel = (hparams.residual_kernel_height,
                       hparams.residual_kernel_width)
    residual_kernel1d = (hparams.residual_kernel_height, 1)
    residual_kernel = residual_kernel1d if self.is1d else residual_kernel
    residual_conv = tf.layers.conv2d
    if hparams.residual_use_separable_conv:
      residual_conv = tf.layers.separable_conv2d
    # Up-convolutions.
    for i in range(hparams.num_hidden_layers):
      j = hparams.num_hidden_layers - i - 1
      if is_training:
        nomix_p = common_layers.inverse_lin_decay(
            int(hparams.bottleneck_warmup_steps * 0.25 * 2**j)) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_%d" % j, nomix_p)
      filters = hparams.hidden_size * 2**j
      filters = min(filters, hparams.max_hidden_size)
      with tf.variable_scope("layer_%d" % i):
        j = hparams.num_hidden_layers - i - 1
        x = tf.layers.conv2d_transpose(
            x, filters, kernel, strides=strides,
            padding="SAME", activation=common_layers.belu, name="strided")
        y = x
        for r in range(hparams.num_residual_layers):
          residual_filters = filters
          if r < hparams.num_residual_layers - 1:
            residual_filters = int(
                filters * hparams.residual_filter_multiplier)
          y = residual_conv(
              y, residual_filters, residual_kernel,
              padding="SAME", activation=common_layers.belu,
              name="residual_%d" % r)
        x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
        x = common_layers.layer_norm(x, name="ln")
        x = common_attention.add_timing_signal_nd(x)
        if encoder_layers is not None:
          enc_x = encoder_layers[j]
          enc_shape = common_layers.shape_list(enc_x)
          x_mix = x[:enc_shape[0], :enc_shape[1], :enc_shape[2], :]
          if is_training:  # Mix at the beginning of training.
            rand = tf.random_uniform(common_layers.shape_list(x_mix))
            x_mix = tf.where(tf.less(rand, nomix_p), x_mix, enc_x)
          x = x_mix
    return x
def embed(self, x):
  """Input embedding with a non-zero bias for uniform inputs."""
  with tf.variable_scope("embed", reuse=tf.AUTO_REUSE):
    x = tf.layers.dense(
        x, self.hparams.hidden_size, name="embed",
        activation=common_layers.belu,
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    return common_attention.add_timing_signal_nd(x)
def embed(self, x):
  """Input embedding with a non-zero bias for uniform inputs."""
  with tf.variable_scope("embed", reuse=tf.AUTO_REUSE):
    x_shape = common_layers.shape_list(x)
    # Merge channels and depth before embedding.
    x = tf.reshape(x, x_shape[:-2] + [x_shape[-2] * x_shape[-1]])
    x = tf.layers.dense(
        x, self.hparams.hidden_size, name="embed",
        activation=common_layers.belu,
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    return common_attention.add_timing_signal_nd(x)
def preprocess_inputs(inputs, hidden_size):
  """Transform input size and add positional encodings."""
  if inputs.shape.as_list()[-1] != hidden_size:
    # Project to proper size.
    inputs = common_layers.conv1d(inputs=inputs, filters=hidden_size,
                                  kernel_size=1, activation=None,
                                  padding='SAME')
  net = inputs
  net = common_attention.add_timing_signal_nd(net)
  return net
def embed(self, x, name="embedding"):
  """Input embedding with a non-zero bias for uniform inputs."""
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    x_shape = common_layers.shape_list(x)
    # Merge channels and depth before embedding.
    x = tf.reshape(x, x_shape[:-2] + [x_shape[-2] * x_shape[-1]])
    x = tf.layers.dense(
        x, self.hparams.hidden_size, name="embed",
        activation=common_layers.belu,
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_layers.layer_norm(x, name="ln_embed")
    return common_attention.add_timing_signal_nd(x)
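# --- Shape sketch (illustrative) ---
# The merged-depth embed variants above collapse the last two axes before
# the dense projection, so each position contributes a single
# channels*depth vector. The dimensions below are made up for the example.
import tensorflow as tf

x = tf.zeros([8, 64, 64, 3, 24])             # [batch, H, W, channels, depth]
merged = tf.reshape(x, [8, 64, 64, 3 * 24])  # [batch, H, W, channels*depth]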
def encoder(self, x):
  with tf.variable_scope("encoder"):
    hparams = self.hparams
    kernel, strides = self._get_kernel_and_strides()
    residual_kernel = (hparams.residual_kernel_height,
                       hparams.residual_kernel_width)
    residual_kernel1d = (hparams.residual_kernel_height, 1)
    residual_kernel = residual_kernel1d if self.is1d else residual_kernel
    residual_conv = tf.layers.conv2d
    if hparams.residual_use_separable_conv:
      residual_conv = tf.layers.separable_conv2d
    # Input embedding with a non-zero bias for uniform inputs.
    x = tf.layers.dense(
        x, hparams.hidden_size, name="embed",
        activation=common_layers.belu,
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)
    # Down-convolutions.
    for i in range(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % i):
        x = self.make_even_size(x)
        x = self.dropout(x)
        filters = hparams.hidden_size * 2**(i + 1)
        filters = min(filters, hparams.max_hidden_size)
        x = tf.layers.conv2d(
            x, filters, kernel, strides=strides,
            padding="SAME", activation=common_layers.belu, name="strided")
        y = x
        for r in range(hparams.num_residual_layers):
          residual_filters = filters
          if r < hparams.num_residual_layers - 1:
            residual_filters = int(
                filters * hparams.residual_filter_multiplier)
          y = residual_conv(
              y, residual_filters, residual_kernel,
              padding="SAME", activation=common_layers.belu,
              name="residual_%d" % r)
        x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
        x = common_layers.layer_norm(x)
    return x
def encoder(self, x):
  with tf.variable_scope("encoder"):
    hparams = self.hparams
    layers = []
    kernel, strides = self._get_kernel_and_strides()
    residual_kernel = (hparams.residual_kernel_height,
                       hparams.residual_kernel_width)
    residual_kernel1d = (hparams.residual_kernel_height, 1)
    residual_kernel = residual_kernel1d if self.is1d else residual_kernel
    residual_conv = tf.layers.conv2d
    if hparams.residual_use_separable_conv:
      residual_conv = tf.layers.separable_conv2d
    # Down-convolutions.
    for i in range(hparams.num_hidden_layers):
      with tf.variable_scope("layer_%d" % i):
        x = self.make_even_size(x)
        layers.append(x)
        x = self.dropout(x)
        filters = hparams.hidden_size * 2**(i + 1)
        filters = min(filters, hparams.max_hidden_size)
        x = common_attention.add_timing_signal_nd(x)
        x = tf.layers.conv2d(
            x, filters, kernel, strides=strides,
            padding="SAME", activation=common_layers.belu, name="strided")
        y = x
        y = tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
        for r in range(hparams.num_residual_layers):
          residual_filters = filters
          if r < hparams.num_residual_layers - 1:
            residual_filters = int(
                filters * hparams.residual_filter_multiplier)
          y = residual_conv(
              y, residual_filters, residual_kernel,
              padding="SAME", activation=common_layers.belu,
              name="residual_%d" % r)
        x += y
        x = common_layers.layer_norm(x, name="ln")
    return x, layers
def body(self, features):
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel1, kernel2 = (3, 3), (4, 4)

  # Embed the inputs.
  inputs_shape = common_layers.shape_list(features["inputs"])
  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      features["inputs"], filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action if present.
  if "input_action" in features:
    action = tf.reshape(features["input_action"][:, -1, :],
                        [-1, 1, 1, hparams.hidden_size])
    action_mask = tf.layers.dense(action, filters, name="action_mask")
    zeros_mask = tf.zeros(common_layers.shape_list(x)[:-1] + [filters],
                          dtype=tf.float32)
    if hparams.concatenate_actions:
      x = tf.concat([x, action_mask + zeros_mask], axis=-1)
    else:
      x *= action_mask + zeros_mask

  x, extra_loss = self.inject_latent(x, features, filters)

  # Run a stack of convolutions.
  for i in range(hparams.num_hidden_layers):
    with tf.variable_scope("layer%d" % i):
      y = tf.layers.conv2d(x, filters, kernel1, activation=common_layers.belu,
                           strides=(1, 1), padding="SAME")
      y = tf.nn.dropout(y, 1.0 - hparams.dropout)
      if i == 0:
        x = y
      else:
        x = common_layers.layer_norm(x + y)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(x, filters, kernel2,
                                     activation=common_layers.belu,
                                     strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]

  # Reward prediction if needed.
  if "target_reward" not in features:
    return x
  reward_pred = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  return {"targets": x, "target_reward": reward_pred}, extra_loss
def body(self, features):
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel1, kernel2 = (3, 3), (4, 4)

  # Embed the inputs.
  inputs_shape = common_layers.shape_list(features["inputs"])
  x = tf.layers.dense(features["inputs"], filters, name="inputs_embed")

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = self.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action.
  action = tf.reshape(features["input_action"][:, -1, :],
                      [-1, 1, 1, hparams.hidden_size])
  action_mask = tf.layers.dense(action, filters, name="action_mask")
  zeros_mask = tf.zeros(common_layers.shape_list(x)[:-1] + [filters],
                        dtype=tf.float32)
  x *= action_mask + zeros_mask

  # Run a stack of convolutions.
  for i in range(hparams.num_hidden_layers):
    with tf.variable_scope("layer%d" % i):
      y = tf.layers.conv2d(x, filters, kernel1, activation=common_layers.belu,
                           strides=(1, 1), padding="SAME")
      y = tf.nn.dropout(y, 1.0 - hparams.dropout)
      if i == 0:
        x = y
      else:
        x = common_layers.layer_norm(x + y)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(x, filters, kernel2,
                                     activation=common_layers.belu,
                                     strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]

  # Reward prediction. (Uses the non-deprecated `keepdims` spelling.)
  reward_pred = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  return {"targets": x, "target_reward": reward_pred}
def body(self, features):
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel1, kernel2 = (3, 3), (4, 4)

  # Embed the inputs.
  inputs_shape = common_layers.shape_list(features["inputs"])
  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      features["inputs"], filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action if present.
  if "input_action" in features:
    action = tf.reshape(features["input_action"][:, -1, :],
                        [-1, 1, 1, hparams.hidden_size])
    action_mask = tf.layers.dense(action, filters, name="action_mask")
    zeros_mask = tf.zeros(common_layers.shape_list(x)[:-1] + [filters],
                          dtype=tf.float32)
    x *= action_mask + zeros_mask

  # Run a stack of convolutions.
  for i in range(hparams.num_hidden_layers):
    with tf.variable_scope("layer%d" % i):
      y = tf.layers.conv2d(x, filters, kernel1, activation=common_layers.belu,
                           strides=(1, 1), padding="SAME")
      y = tf.nn.dropout(y, 1.0 - hparams.dropout)
      if i == 0:
        x = y
      else:
        x = common_layers.layer_norm(x + y)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=common_layers.belu,
          strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]

  # Reward prediction if needed.
  if "target_reward" not in features:
    return x
  reward_pred = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  return {"targets": x, "target_reward": reward_pred}, extra_loss
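# --- Shape sketch (illustrative) ---
# Why the up-convolutions above slice back to the skip connection's shape:
# make_even_size pads odd spatial dims before each stride-2 conv, so the
# matching conv2d_transpose can overshoot the pre-padding size. The 105x80
# frame size is just an example.
import tensorflow as tf

x_even = tf.zeros([1, 106, 80, 16])  # odd height 105 padded to 106
down = tf.layers.conv2d(x_even, 32, (4, 4), strides=(2, 2), padding="SAME")
up = tf.layers.conv2d_transpose(down, 16, (4, 4), strides=(2, 2),
                                padding="SAME")
print(down.shape)  # (1, 53, 40, 32)
print(up.shape)    # (1, 106, 80, 16); sliced to [:, :105, :80, :] to match.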
def inject_latent(self, layer, inputs, target, action):
  """Inject a deterministic latent based on the target frame."""
  hparams = self.hparams
  final_filters = common_layers.shape_list(layer)[-1]
  filters = hparams.hidden_size
  kernel = (4, 4)
  layer_shape = common_layers.shape_list(layer)
  activation_fn = common_layers.belu
  if hparams.activation_fn == "relu":
    activation_fn = tf.nn.relu

  def add_bits(layer, bits):
    # Note: `tfl` is assumed to be an alias for tf.layers.
    z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
    if not hparams.complex_addn:
      return layer + z_mul
    layer *= tf.nn.sigmoid(z_mul)
    z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
    layer += z_add
    return layer

  if not self.is_training:
    if hparams.full_latent_tower:
      rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
      bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    else:
      bits, _ = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size,
          hparams.bottleneck_bits,
          temperature=hparams.latent_predictor_temperature)
      bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
    return add_bits(layer, bits), 0.0

  # Embed.
  frames = tf.concat(inputs + [target], axis=-1)
  x = tfl.dense(
      frames, filters, name="latent_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Add embedded action if present.
  if action is not None:
    x = common_video.inject_additional_input(x, action, "action_enc_latent",
                                             hparams.action_injection)

  if hparams.full_latent_tower:
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("latent_downstride%d" % i):
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tfl.conv2d(x, filters, kernel, activation=activation_fn,
                       strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)
  else:
    x = common_layers.double_discriminator(x)
    x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

  bits, bits_clean = discretization.tanh_discrete_bottleneck(
      x, hparams.bottleneck_bits, hparams.bottleneck_noise,
      hparams.discretize_warmup_steps, hparams.mode)
  pred_loss = 0.0  # Stays zero when the full latent tower is used.
  if not hparams.full_latent_tower:
    _, pred_loss = discretization.predict_bits_with_lstm(
        layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
        target_bits=bits_clean)
    # Mix bits from latent with predicted bits on forward pass as a noise.
    if hparams.latent_rnn_max_sampling > 0.0:
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        bits_pred, _ = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size,
            hparams.bottleneck_bits,
            temperature=hparams.latent_predictor_temperature)
        bits_pred = tf.expand_dims(tf.expand_dims(bits_pred, axis=1), axis=2)
      # Be bits_pred on the forward pass but bits on the backward one.
      bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
      # Select which bits to take from pred sampling with bit_p probability.
      which_bit = tf.random_uniform(common_layers.shape_list(bits))
      bit_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps)
      bit_p *= hparams.latent_rnn_max_sampling
      bits = tf.where(which_bit < bit_p, bits_pred, bits)

  res = add_bits(layer, bits)
  # During training, sometimes skip the latent to help action-conditioning.
  res_p = common_layers.inverse_lin_decay(
      hparams.latent_rnn_warmup_steps / 2)
  res_p *= hparams.latent_use_max_probability
  res_rand = tf.random_uniform([layer_shape[0]])
  res = tf.where(res_rand < res_p, res, layer)
  return res, pred_loss
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_extra):
  del rewards, video_extra
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel2 = (4, 4)

  # Embed the inputs.
  stacked_frames = tf.concat(frames, axis=-1)
  inputs_shape = common_layers.shape_list(stacked_frames)
  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_frames, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action if present.
  if self.has_actions:
    action = actions[-1]
    x = common_video.inject_additional_input(x, action, "action_enc",
                                             hparams.action_injection)

  # Inject latent if present. Only for stochastic models.
  x, extra_loss = self.inject_latent(x, frames, target_frame)
  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x, internal_states = self.middle_network(x, internal_states)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      if self.has_actions:
        x = common_video.inject_additional_input(
            x, action, "action_enc", hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(x, filters, kernel2,
                                     activation=common_layers.belu,
                                     strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  if self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  # No reward prediction if not needed.
  if not self.has_rewards:
    return x, None, extra_loss, internal_states

  # Reward prediction based on middle and final logits.
  reward_pred = tf.concat([x_mid, x_fin], axis=-1)
  reward_pred = tf.nn.relu(
      tf.layers.dense(reward_pred, 128, name="reward_pred"))
  reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
  reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
  return x, reward_pred, extra_loss, internal_states
def body_single(self, features):
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel1, kernel2 = (3, 3), (4, 4)

  # Embed the inputs.
  inputs_shape = common_layers.shape_list(features["inputs"])
  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      features["inputs"], filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action if present.
  if "input_action" in features:
    action = features["input_action"][:, -1, :]
    x = self.inject_additional_input(x, action, "action_enc",
                                     hparams.action_injection)

  x, extra_loss = self.inject_latent(x, features, filters)

  # Run a stack of convolutions.
  for i in range(hparams.num_hidden_layers):
    with tf.variable_scope("layer%d" % i):
      y = tf.nn.dropout(x, 1.0 - hparams.dropout)
      y = tf.layers.conv2d(y, filters, kernel1, activation=common_layers.belu,
                           strides=(1, 1), padding="SAME")
      if i == 0:
        x = y
      else:
        x = common_layers.layer_norm(x + y)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      if "input_action" in features:
        x = self.inject_additional_input(x, action, "action_enc",
                                         hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(x, filters, kernel2,
                                     activation=common_layers.belu,
                                     strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]

  if self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  # Reward prediction if needed.
  if "target_reward" not in features:
    return x
  reward_pred = tf.expand_dims(  # Add a fake channels dim.
      tf.reduce_mean(x, axis=[1, 2], keepdims=True), axis=3)
  return {"targets": x, "target_reward": reward_pred}, extra_loss
def inject_latent(self, layer, inputs, target, action):
  """Inject a deterministic latent based on the target frame."""
  hparams = self.hparams
  final_filters = common_layers.shape_list(layer)[-1]
  filters = hparams.hidden_size
  kernel = (4, 4)
  layer_shape = common_layers.shape_list(layer)

  def add_bits(layer, bits):
    # Note: `tfl` is assumed to be an alias for tf.layers.
    z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
    if not hparams.complex_addn:
      return layer + z_mul
    layer *= tf.nn.sigmoid(z_mul)
    z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
    layer += z_add
    return layer

  if not self.is_training:
    if hparams.full_latent_tower:
      rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
      bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    else:
      bits, _ = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size,
          hparams.bottleneck_bits,
          temperature=hparams.latent_predictor_temperature)
      bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
    return add_bits(layer, bits), 0.0

  # Embed.
  frames = tf.concat(inputs + [target], axis=-1)
  x = tfl.dense(
      frames, filters, name="latent_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Add embedded action if present.
  if action is not None:
    x = common_video.inject_additional_input(
        x, action, "action_enc_latent", hparams.action_injection)

  if hparams.full_latent_tower:
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("latent_downstride%d" % i):
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tfl.conv2d(x, filters, kernel, activation=common_layers.belu,
                       strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)
  else:
    x = common_layers.double_discriminator(x)
    x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

  bits, bits_clean = discretization.tanh_discrete_bottleneck(
      x, hparams.bottleneck_bits, hparams.bottleneck_noise,
      hparams.discretize_warmup_steps, hparams.mode)
  pred_loss = 0.0  # Stays zero when the full latent tower is used.
  if not hparams.full_latent_tower:
    _, pred_loss = discretization.predict_bits_with_lstm(
        layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
        target_bits=bits_clean)
    # Mix bits from latent with predicted bits on forward pass as a noise.
    if hparams.latent_rnn_max_sampling > 0.0:
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        bits_pred, _ = discretization.predict_bits_with_lstm(
            layer, hparams.latent_predictor_state_size,
            hparams.bottleneck_bits,
            temperature=hparams.latent_predictor_temperature)
        bits_pred = tf.expand_dims(tf.expand_dims(bits_pred, axis=1), axis=2)
      # Be bits_pred on the forward pass but bits on the backward one.
      bits_pred = bits_clean + tf.stop_gradient(bits_pred - bits_clean)
      # Select which bits to take from pred sampling with bit_p probability.
      which_bit = tf.random_uniform(common_layers.shape_list(bits))
      bit_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps)
      bit_p *= hparams.latent_rnn_max_sampling
      bits = tf.where(which_bit < bit_p, bits_pred, bits)

  res = add_bits(layer, bits)
  # During training, sometimes skip the latent to help action-conditioning.
  res_p = common_layers.inverse_lin_decay(hparams.latent_rnn_warmup_steps / 2)
  res_p *= hparams.latent_use_max_probability
  res_rand = tf.random_uniform([layer_shape[0]])
  res = tf.where(res_rand < res_p, res, layer)
  return res, pred_loss
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_extra):
  del rewards, video_extra
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel2 = (4, 4)
  action = actions[-1]

  # Stack the inputs.
  if internal_states is not None and hparams.concat_internal_states:
    # Use the first part of the first internal state if asked to concatenate.
    batch_size = common_layers.shape_list(frames[0])[0]
    internal_state = internal_states[0][0][:batch_size, :, :, :]
    stacked_frames = tf.concat(frames + [internal_state], axis=-1)
  else:
    stacked_frames = tf.concat(frames, axis=-1)
  inputs_shape = common_layers.shape_list(stacked_frames)

  # Update internal states early if requested.
  if hparams.concat_internal_states:
    internal_states = self.update_internal_states_early(
        internal_states, frames)

  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_frames, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action if present.
  if self.has_actions:
    x = common_video.inject_additional_input(
        x, action, "action_enc", hparams.action_injection)

  # Inject latent if present. Only for stochastic models.
  x, extra_loss = self.inject_latent(x, frames, target_frame, action)

  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x, internal_states = self.middle_network(x, internal_states)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      if self.has_actions:
        x = common_video.inject_additional_input(
            x, action, "action_enc", hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=common_layers.belu,
          strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  if self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  # No reward prediction if not needed.
  if not self.has_rewards:
    return x, None, extra_loss, internal_states

  # Reward prediction based on middle and final logits.
  reward_pred = tf.concat([x_mid, x_fin], axis=-1)
  reward_pred = tf.nn.relu(tf.layers.dense(
      reward_pred, 128, name="reward_pred"))
  reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
  reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
  return x, reward_pred, extra_loss, internal_states
def inject_latent(self, layer, features, filters):
  """Inject a deterministic latent based on the target frame."""
  del filters
  hparams = self.hparams
  final_filters = common_layers.shape_list(layer)[-1]
  filters = hparams.hidden_size
  kernel = (4, 4)
  layer_shape = common_layers.shape_list(layer)
  batch_size = layer_shape[0]
  state_size = hparams.latent_predictor_state_size
  lstm_cell = tf.contrib.rnn.LSTMCell(state_size)
  discrete_predict = tf.layers.Dense(256, name="discrete_predict")
  discrete_embed = tf.layers.Dense(state_size, name="discrete_embed")

  def add_d(layer, d):
    z_mul = tf.layers.dense(d, final_filters, name="unbottleneck_mul")
    if not hparams.complex_addn:
      return layer + z_mul
    layer *= tf.nn.sigmoid(z_mul)
    z_add = tf.layers.dense(d, final_filters, name="unbottleneck_add")
    layer += z_add
    return layer

  if self.is_predicting:
    if hparams.full_latent_tower:
      rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
    else:
      # Note: `prod` is a product-of-dimensions helper (e.g. numpy.prod),
      # assumed to be imported elsewhere in the original module.
      layer_pred = tf.reshape(layer, [batch_size, prod(layer_shape[1:])])
      prediction = tf.layers.dense(layer_pred, state_size, name="istate")
      c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
      m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
      state = (c_state, m_state)
      outputs = []
      for i in range(hparams.bottleneck_bits // 8):
        output, state = lstm_cell(prediction, state)
        discrete_logits = discrete_predict(output)
        discrete_samples = common_layers.sample_with_temperature(
            discrete_logits, hparams.latent_predictor_temperature)
        outputs.append(tf.expand_dims(discrete_samples, axis=1))
        prediction = discrete_embed(tf.one_hot(discrete_samples, 256))
      outputs = tf.concat(outputs, axis=1)
      outputs = discretization.int_to_bit(outputs, 8)
      rand = tf.reshape(outputs, [batch_size, 1, 1, hparams.bottleneck_bits])
    d = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    return add_d(layer, d), 0.0

  # Embed.
  frames = tf.concat([features["cur_target_frame"], features["inputs"]],
                     axis=-1)
  x = tf.layers.dense(
      frames, filters, name="latent_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)
  if hparams.full_latent_tower:
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("latent_downstride%d" % i):
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tf.layers.conv2d(x, filters, kernel,
                             activation=common_layers.belu,
                             strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)
  else:
    x = common_layers.double_discriminator(x)
    x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

  x = tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck")
  x0 = tf.tanh(x)
  d = x0 + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x0)) - 1.0 - x0)

  pred_loss = 0.0
  if not hparams.full_latent_tower:
    d_pred = tf.reshape(tf.maximum(tf.stop_gradient(d), 0),
                        [batch_size, hparams.bottleneck_bits // 8, 8])
    d_int = discretization.bit_to_int(d_pred, 8)
    tf.summary.histogram("d_int", tf.reshape(d_int, [-1]))
    d_hot = tf.one_hot(d_int, 256, axis=-1)
    d_pred = discrete_embed(d_hot)
    layer_pred = tf.reshape(layer, [batch_size, prod(layer_shape[1:])])
    prediction0 = tf.layers.dense(layer_pred, state_size, name="istate")
    c_state = tf.layers.dense(layer_pred, state_size, name="cstate")
    m_state = tf.layers.dense(layer_pred, state_size, name="mstate")
    pred = tf.concat([tf.expand_dims(prediction0, axis=1), d_pred], axis=1)
    state = (c_state, m_state)
    outputs = []
    for i in range(hparams.bottleneck_bits // 8):
      output, state = lstm_cell(pred[:, i, :], state)
      outputs.append(tf.expand_dims(output, axis=1))
    outputs = tf.concat(outputs, axis=1)
    d_int_pred = discrete_predict(outputs)
    pred_loss = tf.losses.sparse_softmax_cross_entropy(
        logits=d_int_pred, labels=d_int)
    pred_loss = tf.reduce_mean(pred_loss)

  if hparams.mode == tf.estimator.ModeKeys.TRAIN:
    x += tf.truncated_normal(common_layers.shape_list(x),
                             mean=0.0, stddev=0.2)
    x = tf.tanh(x)
    noise = tf.random_uniform(common_layers.shape_list(x))
    noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise, noise)) - 1.0
    x *= noise
    d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x)
    p = common_layers.inverse_lin_decay(hparams.discrete_warmup_steps)
    d = tf.where(tf.less(tf.random_uniform([batch_size]), p), d, x)
  return add_d(layer, d), pred_loss
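# --- Scheduling sketch (illustrative) ---
# The warmup gate above picks the discretized code `d` over the continuous
# code `x` with a probability that ramps up during training. The stand-in
# below mimics the overall behavior of common_layers.inverse_lin_decay (a
# linear ramp from ~0 to 1 over `max_step` global steps); the real helper
# differs in details.
import tensorflow as tf

def linear_ramp(max_step):
  step = tf.to_float(tf.train.get_or_create_global_step())
  return tf.minimum(step / float(max_step), 1.0)

batch_size = 4
x = tf.random_normal([batch_size, 1, 1, 64])  # continuous code
d = tf.sign(x)                                # discretized stand-in
p = linear_ramp(10000)                        # discretization probability
pick = tf.less(tf.random_uniform([batch_size]), p)
mixed = tf.where(pick, d, x)                  # per-example choice, as above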
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_extra):
  del rewards, video_extra
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel2 = (4, 4)
  action = actions[-1]
  activation_fn = common_layers.belu
  if self.hparams.activation_fn == "relu":
    activation_fn = tf.nn.relu

  # Normalize frames.
  frames = [common_layers.standardize_images(f) for f in frames]

  # Stack the inputs.
  if internal_states is not None and hparams.concat_internal_states:
    # Use the first part of the first internal state if asked to concatenate.
    batch_size = common_layers.shape_list(frames[0])[0]
    internal_state = internal_states[0][0][:batch_size, :, :, :]
    stacked_frames = tf.concat(frames + [internal_state], axis=-1)
  else:
    stacked_frames = tf.concat(frames, axis=-1)
  inputs_shape = common_layers.shape_list(stacked_frames)

  # Update internal states early if requested.
  if hparams.concat_internal_states:
    internal_states = self.update_internal_states_early(
        internal_states, frames)

  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_frames, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=activation_fn,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  if self.has_actions:
    with tf.variable_scope("policy"):
      x_flat = tf.layers.flatten(x)
      policy_pred = tf.layers.dense(x_flat, self.hparams.problem.num_actions)
      value_pred = tf.layers.dense(x_flat, 1)
      value_pred = tf.squeeze(value_pred, axis=-1)
  else:
    policy_pred, value_pred = None, None

  # Add embedded action if present.
  if self.has_actions:
    x = common_video.inject_additional_input(x, action, "action_enc",
                                             hparams.action_injection)

  # Inject latent if present. Only for stochastic models.
  norm_target_frame = common_layers.standardize_images(target_frame)
  x, extra_loss = self.inject_latent(x, frames, norm_target_frame, action)

  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x, internal_states = self.middle_network(x, internal_states)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      if self.has_actions:
        x = common_video.inject_additional_input(
            x, action, "action_enc", hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(x, filters, kernel2,
                                     activation=activation_fn,
                                     strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)

  if hparams.do_autoregressive_rnn:
    # If enabled, we predict the target frame autoregressively using RNNs.
    # To this end, the current prediction is flattened into one long sequence
    # of sub-pixels, and so is the target frame. Each sub-pixel (RGB value,
    # from 0 to 255) is predicted with an RNN. To avoid doing as many steps
    # as width * height * channels, we only use a number of pixels back,
    # as many as hparams.autoregressive_rnn_lookback.
    with tf.variable_scope("autoregressive_rnn"):
      batch_size = common_layers.shape_list(frames[0])[0]
      # Height, width, channels and lookback are the constants we need.
      h, w = inputs_shape[1], inputs_shape[2]  # 105, 80 on Atari games
      c = hparams.problem.num_channels
      lookback = hparams.autoregressive_rnn_lookback
      assert (h * w) % lookback == 0, "Number of pixels must divide lookback."
      m = (h * w) // lookback  # Batch size multiplier for the RNN.
      # These are logits that will be used as inputs to the RNN.
      rnn_inputs = tf.layers.dense(x, c * 64, name="rnn_inputs")
      # They are of shape [batch_size, h, w, c, 64], reshaping now.
      rnn_inputs = tf.reshape(rnn_inputs, [batch_size * m, lookback * c, 64])
      # Same for the target frame.
      rnn_target = tf.reshape(target_frame, [batch_size * m, lookback * c])
      # Construct rnn starting state: flatten rnn_inputs, apply a relu layer.
      rnn_start_state = tf.nn.relu(
          tf.layers.dense(tf.nn.relu(tf.layers.flatten(rnn_inputs)), 256,
                          name="rnn_start_state"))
      # Our RNN function API is on bits, each subpixel has 8 bits.
      total_num_bits = lookback * c * 8
      # We need to provide RNN targets as bits (due to the API).
      rnn_target_bits = discretization.int_to_bit(rnn_target, 8)
      rnn_target_bits = tf.reshape(rnn_target_bits,
                                   [batch_size * m, total_num_bits])
      if self.is_training:
        # Run the RNN in training mode, add its loss to the losses.
        rnn_predict, rnn_loss = discretization.predict_bits_with_lstm(
            rnn_start_state, 128, total_num_bits,
            target_bits=rnn_target_bits, extra_inputs=rnn_inputs)
        extra_loss += rnn_loss
        # We still use non-RNN predictions too in order to guide the network.
        x = tf.layers.dense(x, c * 256, name="logits")
        x = tf.reshape(x, [batch_size, h, w, c, 256])
        rnn_predict = tf.reshape(rnn_predict, [batch_size, h, w, c, 256])
        # Mix non-RNN and RNN predictions so that after warmup the RNN is 90%.
        x = tf.reshape(tf.nn.log_softmax(x), [batch_size, h, w, c * 256])
        rnn_predict = tf.nn.log_softmax(rnn_predict)
        rnn_predict = tf.reshape(rnn_predict, [batch_size, h, w, c * 256])
        alpha = 0.9 * common_layers.inverse_lin_decay(
            hparams.autoregressive_rnn_warmup_steps)
        x = alpha * rnn_predict + (1.0 - alpha) * x
      else:
        # In prediction mode, run the RNN without any targets.
        bits, _ = discretization.predict_bits_with_lstm(
            rnn_start_state, 128, total_num_bits, extra_inputs=rnn_inputs,
            temperature=0.0)  # No sampling from this RNN, just greedy.
        # The output is in bits, get back the predicted pixels.
        bits = tf.reshape(bits, [batch_size * m, lookback * c, 8])
        ints = discretization.bit_to_int(tf.maximum(bits, 0), 8)
        ints = tf.reshape(ints, [batch_size, h, w, c])
        x = tf.reshape(tf.one_hot(ints, 256), [batch_size, h, w, c * 256])
  elif self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  reward_pred = None
  if self.has_rewards:
    # Reward prediction based on middle and final logits.
    reward_pred = tf.concat([x_mid, x_fin], axis=-1)
    reward_pred = tf.nn.relu(
        tf.layers.dense(reward_pred, 128, name="reward_pred"))
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

  return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
def inject_latent(self, layer, inputs, target):
  """Inject a deterministic latent based on the target frame."""
  hparams = self.hparams
  final_filters = common_layers.shape_list(layer)[-1]
  filters = hparams.hidden_size
  kernel = (4, 4)
  layer_shape = common_layers.shape_list(layer)

  def add_bits(layer, bits):
    # Note: `tfl` is assumed to be an alias for tf.layers.
    z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
    if not hparams.complex_addn:
      return layer + z_mul
    layer *= tf.nn.sigmoid(z_mul)
    z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
    layer += z_add
    return layer

  if not self.is_training:
    if hparams.full_latent_tower:
      rand = tf.random_uniform(layer_shape[:-1] + [hparams.bottleneck_bits])
      bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
    else:
      bits, _ = discretization.predict_bits_with_lstm(
          layer, hparams.latent_predictor_state_size,
          hparams.bottleneck_bits,
          temperature=hparams.latent_predictor_temperature)
      bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
    return add_bits(layer, bits), 0.0

  # Embed.
  frames = tf.concat(inputs + [target], axis=-1)
  x = tfl.dense(
      frames, filters, name="latent_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)
  if hparams.full_latent_tower:
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("latent_downstride%d" % i):
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tfl.conv2d(x, filters, kernel, activation=common_layers.belu,
                       strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)
  else:
    x = common_layers.double_discriminator(x)
    x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

  bits, bits_clean = discretization.tanh_discrete_bottleneck(
      x, hparams.bottleneck_bits, hparams.bottleneck_noise,
      hparams.discretize_warmup_steps, hparams.mode)
  pred_loss = 0.0  # Stays zero when the full latent tower is used.
  if not hparams.full_latent_tower:
    _, pred_loss = discretization.predict_bits_with_lstm(
        layer, hparams.latent_predictor_state_size, hparams.bottleneck_bits,
        target_bits=bits_clean)
  return add_bits(layer, bits), pred_loss
def network(self):

  def middle_network(layer):
    # Run a stack of convolutions.
    x = layer
    kernel1 = (3, 3)
    filters = common_layers.shape_list(x)[-1]
    for i in range(2):
      with tf.variable_scope("layer%d" % i):
        y = tf.nn.dropout(x, 1.0 - 0.5)
        y = tf.layers.conv2d(y, filters, kernel1,
                             activation=self.activation_fn,
                             strides=(1, 1), padding="SAME")
        if i == 0:
          x = y
        else:
          x = common_layers.layer_norm(x + y)
    return x

  batch_size = tf.shape(self.states_ph)[0]
  filters = self.hidden_size
  kernel2 = (4, 4)
  action = self.actions_oph  # [0] NOTE - might remove this

  # Normalize states.
  if self.n_envs > 1:
    states = [
        common_layers.standardize_images(self.states_ph[i, :, :, :])
        for i in range(self.n_envs)
    ]
    stacked_states = tf.stack(states)
  else:
    stacked_states = common_layers.standardize_images(self.states_ph)
  inputs_shape = common_layers.shape_list(stacked_states)

  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_states, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(self.layers):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.dropout_p)
      x = common_layers.make_even_size(x)
      if i < 2:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=self.activation_fn,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  if self.is_policy:
    with tf.variable_scope("policy"):
      x_flat = tf.layers.flatten(x)
      policy_pred = tf.layers.dense(x_flat, self.action_dim)
      value_pred = tf.layers.dense(x_flat, 1)
      value_pred = tf.squeeze(value_pred, axis=-1)
  else:
    policy_pred, value_pred = None, None

  # if self.has_actions:
  x = inject_additional_input(x, action, "action_enc", "multi_additive")

  # Inject latent if present. Only for stochastic models.
  target_states = common_layers.standardize_images(self.target_states)

  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x = middle_network(x)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(self.layers):
    with tf.variable_scope("upstride%d" % i):
      x = tf.nn.dropout(x, 1.0 - 0.1)
      if i >= self.layers - 2:
        filters //= 2
      x = tf.layers.conv2d_transpose(x, filters, kernel2,
                                     activation=self.activation_fn,
                                     strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x = tf.layers.dense(x, self.depth, name="logits")

  reward_pred = None
  if self.has_rewards:
    # Reward prediction based on middle and final logits.
    reward_pred = tf.concat([x_mid, x_fin], axis=-1)
    reward_pred = tf.nn.relu(
        tf.layers.dense(reward_pred, 128, name="reward_pred"))
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

  return x, reward_pred, policy_pred, value_pred
def decoder(self, x, encoder_layers=None):
  with tf.variable_scope("decoder"):
    hparams = self.hparams
    is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
    kernel, strides = self._get_kernel_and_strides()
    residual_kernel = (hparams.residual_kernel_height,
                       hparams.residual_kernel_width)
    residual_kernel1d = (hparams.residual_kernel_height, 1)
    residual_kernel = residual_kernel1d if self.is1d else residual_kernel
    residual_conv = tf.layers.conv2d
    if hparams.residual_use_separable_conv:
      residual_conv = tf.layers.separable_conv2d
    # Up-convolutions.
    for i in range(hparams.num_hidden_layers):
      j = hparams.num_hidden_layers - i - 1
      if is_training:
        nomix_p = common_layers.inverse_lin_decay(
            int(hparams.bottleneck_warmup_steps * 0.25 * 2**j)) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_%d" % j, nomix_p)
      filters = hparams.hidden_size * 2**j
      filters = min(filters, hparams.max_hidden_size)
      with tf.variable_scope("layer_%d" % i):
        j = hparams.num_hidden_layers - i - 1
        x = tf.layers.conv2d_transpose(
            x, filters, kernel, strides=strides,
            padding="SAME", activation=common_layers.belu, name="strided")
        y = x
        for r in range(hparams.num_residual_layers):
          residual_filters = filters
          if r < hparams.num_residual_layers - 1:
            residual_filters = int(
                filters * hparams.residual_filter_multiplier)
          y = residual_conv(
              y, residual_filters, residual_kernel,
              padding="SAME", activation=common_layers.belu,
              name="residual_%d" % r)
        x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
        x = common_layers.layer_norm(x, name="ln")
        x = common_attention.add_timing_signal_nd(x)
        if encoder_layers is not None:
          enc_x = encoder_layers[j]
          enc_shape = common_layers.shape_list(enc_x)
          x_mix = x[:enc_shape[0], :enc_shape[1], :enc_shape[2], :]
          if is_training:  # Mix at the beginning of training.
            rand = tf.random_uniform(common_layers.shape_list(x_mix))
            x_mix = tf.where(tf.less(rand, nomix_p), x_mix, enc_x)
          if hparams.gan_loss_factor != 0:
            x_gan = x[enc_shape[0]:, :enc_shape[1], :enc_shape[2], :]
            x = tf.concat([x_mix, x_gan], axis=0)
          else:
            x = x_mix
    return x
def sequence_encoder(inputs, length, is_training, cfg):
  """Encode a sequence using self attention, convolutions, and dense layers.

  Args:
    inputs: [batch x length x depth] tensor to encode
    length: [batch] tensor containing length of each sequence as an int
    is_training: bool indicating whether we are training
    cfg: Layer configuration

  Returns:
    Encoded sequence

  Raises:
    ValueError: If cfg.structure is invalid.
  """
  cfg = utils.Config(cfg)
  assert length is not None
  assert is_training in [False, True]

  # Turn off dropout at test time.
  if not is_training:
    for k in cfg:
      if 'dropout' in k:
        cfg[k] = 0.0

  # Mask out padding tokens during attention.
  maxlen = None
  if is_training:  # All dimensions must be static on a TPU.
    maxlen = inputs.shape.as_list()[1]
  _, attention_bias = get_attention_bias(length, maxlen=maxlen)

  if inputs.shape.as_list()[-1] != cfg.hidden_size:
    # Project to internal size.
    inputs = common_layers.conv1d(inputs=inputs, filters=cfg.hidden_size,
                                  kernel_size=1, activation=None,
                                  padding='SAME')

  net = inputs
  if cfg.timing_signal:
    net = common_attention.add_timing_signal_nd(net)

  structure = cfg.structure.split(',') * cfg.layers
  for layer_id, layer_type in enumerate(structure):
    with tf.variable_scope('%s_%d' % (layer_type, layer_id)):
      layer_input = net
      net = common_layers.layer_norm(net)
      if layer_type == 'att':
        net = common_attention.multihead_attention(
            query_antecedent=net,
            memory_antecedent=None,
            bias=attention_bias,
            total_key_depth=cfg.hidden_size,
            total_value_depth=cfg.hidden_size,
            output_depth=cfg.hidden_size,
            num_heads=cfg.attention_heads,
            dropout_rate=cfg.attention_dropout,
            attention_type=cfg.attention_type,
            make_image_summary=False)
      elif layer_type == 'conv':
        if cfg.separable_conv:
          net = separable_conv(net, filters=cfg.hidden_size,
                               kernel_size=cfg.kernel_size,
                               activation=tf.nn.relu)
        else:
          net = common_layers.conv1d(inputs=net, filters=cfg.hidden_size,
                                     kernel_size=cfg.kernel_size,
                                     activation=tf.nn.relu,
                                     padding='SAME')
      elif layer_type == 'ffn':
        # TODO(ddohan): See how expert_utils used to do the dense layer.
        net = tf.layers.dense(
            net, units=int(cfg.ffn_multiplier * cfg.hidden_size),
            activation=tf.nn.relu)
        net = tf.layers.dense(net, units=cfg.hidden_size, activation=None)
      else:
        raise ValueError('Unknown layer type %s' % layer_type)
      if cfg.layer_dropout:
        net = tf.nn.dropout(net, keep_prob=1.0 - cfg.layer_dropout)
      net += layer_input

  net = common_layers.layer_norm(net)
  return net
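# --- Hypothetical usage sketch ---
# The cfg keys mirror the ones read by sequence_encoder above; the values
# (and passing a plain dict to be wrapped by utils.Config) are assumptions
# for illustration, not taken from the original project.
import tensorflow as tf

cfg = {
    'hidden_size': 128,
    'layers': 2,
    'structure': 'att,conv,ffn',
    'attention_heads': 4,
    'attention_dropout': 0.1,
    'attention_type': 'dot_product',
    'kernel_size': 3,
    'separable_conv': False,
    'ffn_multiplier': 4.0,
    'timing_signal': True,
    'layer_dropout': 0.1,
}
inputs = tf.zeros([8, 40, 128])  # [batch, length, depth]
length = tf.fill([8], 40)        # true sequence lengths
encoded = sequence_encoder(inputs, length, is_training=True, cfg=cfg)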