def model_fn_body(self, features):
  hparams = self._hparams
  targets = features["targets"]
  inputs = features["inputs"]
  target_space = features["target_space_id"]

  # Flatten the 4-D inputs/targets to 3-D [batch, length, depth].
  inputs = common_layers.flatten4d3d(inputs)
  targets = common_layers.flatten4d3d(targets)

  # Prepare encoder/decoder inputs and attention biases.
  (encoder_input, encoder_self_attention_bias,
   encoder_decoder_attention_bias) = (
       transformer.transformer_prepare_encoder(inputs, target_space, hparams))
  (decoder_input,
   decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
       targets, hparams)

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  decoder_input = tf.nn.dropout(decoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  # Reversible (RevNet) encoder and decoder stacks.
  encoder_output = transformer_revnet_encoder(
      encoder_input, encoder_self_attention_bias, hparams)

  decoder_output = transformer_revnet_decoder(
      decoder_input, encoder_output, decoder_self_attention_bias,
      encoder_decoder_attention_bias, hparams)
  decoder_output = tf.expand_dims(decoder_output, 2)

  return decoder_output
def ae_transformer_internal(inputs, targets, target_space, hparams):
  """AE Transformer, main step used for training."""
  with tf.variable_scope("ae_transformer"):
    # Prepare inputs, targets, k.
    k = 2**hparams.num_compress_steps
    _, targets = common_layers.pad_to_same_length(
        targets, targets, final_length_divisible_by=k)
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")

    # Compress and ae.
    ae, hot, kl = ae_compress(targets, hparams.is_2d, hparams, "ae")
    tf.summary.histogram("hot", tf.reshape(tf.argmax(hot, axis=-1), [-1]))
    emb = ae_embed(hot, hparams, "ae", reuse=True)

    # Compress context and run autoregressive decoder on emb-hot.
    emb_flat = tf.expand_dims(common_layers.flatten4d3d(emb), axis=2)
    emb_flat = tf.stop_gradient(emb_flat)
    dec_c = decode(None, None, emb_flat, inputs, ed, hparams)
    dec_c = tf.reshape(dec_c, tf.shape(emb))
    c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context")
    reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=hot, logits=c_z)

    # If not training, use the predicted z instead of the autoregressive one.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size)

    # Decompress, pass for ae loss.
    z = ae_decompress(emb, ae, targets, hparams.is_2d, hparams, "ae")
    kl *= common_layers.inverse_exp_decay(
        int(hparams.startup_steps * 0.8), min_value=0.0001)
    reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps)
    losses = {"kl": kl, "reconstruction": reconstruct_loss * 0.1}
    return z, losses
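# The "kl" and "reconstruction" losses above are scaled by
# common_layers.inverse_exp_decay, a warm-up coefficient that grows from a
# small value toward 1.0 over roughly the first hparams.startup_steps of
# training. The helper below is only an illustrative sketch of that kind of
# schedule (assumed behavior, not the T2T implementation); `global_step`
# would come from tf.train.get_global_step().
def _warmup_coefficient_sketch(global_step, max_step, min_value=0.01):
  """Illustrative warm-up: ~min_value at step 0, 1.0 at and after max_step."""
  base = min_value**(1.0 / float(max_step))  # base**max_step == min_value.
  step = tf.to_float(global_step)
  return tf.pow(base, tf.maximum(float(max_step) - step, 0.0))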
def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    encoder_outputs, final_encoder_state = lstm(
        tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_right(targets)
    decoder_outputs, _ = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs)
    return tf.expand_dims(decoder_outputs, axis=2)
def bytenet_internal(inputs, targets, hparams):
  """ByteNet, main step used for training."""
  with tf.variable_scope("bytenet"):
    # Flatten inputs and extend length by 50%.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    extend_length = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1]))
    inputs_shape = inputs.shape.as_list()
    inputs = tf.pad(inputs, [[0, 0], [0, extend_length], [0, 0], [0, 0]])
    inputs_shape[1] = None
    inputs.set_shape(inputs_shape)  # Don't lose the other shapes when padding.

    # Pad inputs and targets to be the same length, divisible by 50.
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=50)
    final_encoder = residual_dilated_conv(inputs, hparams.num_block_repeat,
                                          "SAME", "encoder", hparams)

    shifted_targets = common_layers.shift_right(targets)
    kernel = (hparams.kernel_height, hparams.kernel_width)
    decoder_start = common_layers.conv_block(
        tf.concat([final_encoder, shifted_targets], axis=3),
        hparams.hidden_size, [((1, 1), kernel)], padding="LEFT")

    return residual_dilated_conv(decoder_start, hparams.num_block_repeat,
                                 "LEFT", "decoder", hparams)
def encode(self, inputs, target_space, hparams):
  """Encode transformer inputs.

  Args:
    inputs: Transformer inputs [batch_size, input_length, hidden_dim].
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
  """
  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer_prepare_encoder(inputs, target_space, hparams))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  encoder_output = transformer_encoder(encoder_input, self_attention_bias,
                                       hparams)

  return encoder_output, encoder_decoder_attention_bias
def preprocess_targets(targets, i):
  """Performs preprocessing steps on the targets to prepare for the decoder.

  This includes:
    - Embedding the ids.
    - Flattening to 3D tensor.
    - Optionally adding timing signals.

  Args:
    targets: input ids to the decoder. [batch_size, 1]
    i: scalar, Step number of the decoding loop.

  Returns:
    Processed targets [batch_size, 1, hidden_dim]
  """
  # _shard_features called to ensure that the variable names match.
  targets = self._shard_features({"targets": targets})["targets"]
  with tf.variable_scope(target_modality.name):
    targets = target_modality.targets_bottom_sharded(targets, dp)[0]
  targets = common_layers.flatten4d3d(targets)

  # TODO(llion): Explain! Is this even needed?
  targets = tf.cond(
      tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

  if hparams.pos == "timing":
    targets += timing_signal[:, i:i + 1]
  return targets
def model_fn_body(self, features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams

  inputs = features["inputs"]
  target_space = features["target_space_id"]

  encoder_output, encoder_decoder_attention_bias = self.encode(
      inputs, target_space, hparams)

  targets = features["targets"]
  targets = common_layers.flatten4d3d(targets)

  decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
      targets, hparams)

  return self.decode(decoder_input, encoder_output,
                     encoder_decoder_attention_bias,
                     decoder_self_attention_bias, hparams)
def slicenet_internal(inputs, targets, target_space, problem_idx, hparams):
  """The slicenet model, main step used for training."""
  with tf.variable_scope("slicenet"):
    # Flatten inputs and encode.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    inputs_mask = 1.0 - embedding_to_padding(inputs)
    inputs = common_layers.add_timing_signal(inputs)  # Add position info.
    target_space_emb = embed_target_space(target_space, hparams.hidden_size)
    extra_layers = int(hparams.num_hidden_layers * 1.5)
    inputs_encoded = multi_conv_res(
        inputs, "SAME", "encoder", extra_layers, hparams, mask=inputs_mask)
    target_modality_name = hparams.problems[problem_idx].target_modality.name
    if "class_label_modality" in target_modality_name:
      # If we're just predicting a class, there is no use for a decoder.
      return inputs_encoded
    # Do the middle part.
    decoder_start, similarity_loss = slicenet_middle(
        inputs_encoded, targets, target_space_emb, inputs_mask, hparams)
    # Decode.
    decoder_final = multi_conv_res(
        decoder_start,
        "LEFT",
        "decoder",
        hparams.num_hidden_layers,
        hparams,
        mask=inputs_mask,
        source=inputs_encoded)
    return decoder_final, tf.reduce_mean(similarity_loss)
def lstm_seq2seq_internal(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model, main step used for training."""
  with tf.variable_scope("lstm_seq2seq"):
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    _, final_encoder_state = lstm(
        tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        hparams,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)
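# Both LSTM bodies above feed the decoder with common_layers.shift_right,
# which implements teacher forcing: the gold targets are shifted one position
# to the right so the prediction at step t only conditions on steps < t.
# A rough, purely illustrative equivalent for 4-D [batch, time, 1, depth]
# targets (the real helper also supports a custom pad_value):
def _shift_right_sketch(targets):
  """Pad one zero step at the front of the time axis and drop the last step."""
  padded = tf.pad(targets, [[0, 0], [1, 0], [0, 0], [0, 0]])
  return padded[:, :-1, :, :]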
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
  """Middle part of slicenet, connecting encoder and decoder."""

  def norm_fn(x, name):
    with tf.variable_scope(name, default_name="norm"):
      return common_layers.apply_norm(x, hparams.norm_type,
                                      hparams.hidden_size,
                                      hparams.norm_epsilon)

  # Flatten targets and embed target_space_id.
  targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
  target_space_emb = tf.tile(target_space_emb,
                             [tf.shape(targets_flat)[0], 1, 1, 1])

  # Calculate similarity loss (but don't run if not needed).
  if len(hparams.problems) > 1 and hparams.sim_loss_mult > 0.00001:
    targets_timed = common_layers.add_timing_signal(targets_flat)
    extra_layers = int(hparams.num_hidden_layers * 1.5)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      targets_encoded = multi_conv_res(targets_timed, "SAME", "encoder",
                                       extra_layers, hparams)
    with tf.variable_scope("similarity_loss"):
      similarity_loss = similarity_cost(inputs_encoded, targets_encoded)
      similarity_loss *= hparams.sim_loss_mult
  else:
    similarity_loss = 0.0

  # Use attention from each target to look at input and retrieve.
  targets_shifted = common_layers.shift_right(
      targets_flat, pad_value=target_space_emb)
  if hparams.attention_type == "none":
    targets_with_attention = tf.zeros_like(targets_shifted)
  else:
    inputs_padding_bias = (1.0 - mask) * -1e9  # Bias to not attend to padding.
    targets_with_attention = attention(
        targets_shifted, inputs_encoded, norm_fn, hparams,
        bias=inputs_padding_bias)

  # Positional targets: merge attention and raw.
  kernel = (hparams.kernel_height, hparams.kernel_width)
  targets_merged = common_layers.subseparable_conv_block(
      tf.concat([targets_with_attention, targets_shifted], axis=3),
      hparams.hidden_size, [((1, 1), kernel)],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separability=4,
      name="targets_merge")

  return targets_merged, similarity_loss
def model_fn_body(self, features): hparams = self._hparams targets = features["targets"] inputs = features.get("inputs") target_space = features.get("target_space_id") inputs = common_layers.flatten4d3d(inputs) targets = common_layers.flatten4d3d(targets) (encoder_input, encoder_attention_bias, _) = (transformer.transformer_prepare_encoder(inputs, target_space, hparams)) (decoder_input, _) = (transformer.transformer_prepare_decoder(targets, hparams)) encoder_mask = bias_to_mask(encoder_attention_bias) def residual_fn(x, y): return common_layers.layer_norm( x + tf.nn.dropout(y, 1.0 - hparams.residual_dropout)) encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout) decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.residual_dropout) encoder_output = alt_transformer_encoder(encoder_input, residual_fn, encoder_mask, hparams) decoder_output = alt_transformer_decoder(decoder_input, encoder_output, residual_fn, encoder_attention_bias, hparams) decoder_output = tf.expand_dims(decoder_output, 2) return decoder_output
def targets_bottom(self, inputs):
  with tf.variable_scope(self.name):
    # Reshape inputs to 2-d tensor and embed the RGB pixel values.
    shape = tf.shape(inputs)
    inputs = common_layers.flatten4d3d(inputs)
    ret = common_layers.embedding(
        tf.to_int32(inputs),
        self.top_dimensionality,
        self._body_input_depth,
        name="input_rgb_embedding")
    if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
      ret *= self._body_input_depth**0.5

    ret = tf.reshape(
        ret, [shape[0], shape[1], shape[2], self._body_input_depth * 3])
    return tf.layers.dense(ret, self._body_input_depth)
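# Shape flow of targets_bottom above, assuming integer RGB inputs of shape
# [batch, height, width, 3] and self._body_input_depth == depth:
#   flatten4d3d            -> [batch, height * width, 3]
#   embedding (per value)  -> [batch, height * width, 3, depth]
#   reshape                -> [batch, height, width, 3 * depth]
#   dense                  -> [batch, height, width, depth]
# i.e. each pixel's three channel embeddings are concatenated and then
# projected back down to a single depth-sized vector.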
def model_fn_body(self, features): hparams = self._hparams targets = features["targets"] targets = common_layers.flatten4d3d(targets) (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder( targets, hparams) decoder_input = tf.nn.dropout( decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) decoder_output = transformer_decoder(decoder_input, None, decoder_self_attention_bias, None, hparams) decoder_output = tf.expand_dims(decoder_output, 2) return decoder_output
def model_fn_body(self, features): inputs = features["inputs"] inputs.get_shape().assert_has_rank(4) hp = self._hparams out = inputs out = common_layers.flatten4d3d(out) # Conv layers assert hp.num_conv_layers == len(hp.pooling_windows) for i in xrange(hp.num_conv_layers): out = conv_layer( out, hp.hidden_size, hp.kernel_width, hp.stride, hp.pooling_windows[i], hp.dropout, dilation_rate=1, name="conv_%d" % (i + 1)) # Dense dilated conv layers for i in xrange(hp.num_dconv_layers): dilation_rate = 2**(i + 1) dconv_out = conv_layer( out, hp.hidden_size, hp.kernel_width, stride=1, pooling_window=0, dropout_rate=hp.dropout, dilation_rate=dilation_rate, name="dconv_%d" % (i + 1)) out = tf.concat([out, dconv_out], axis=2) # Fully connected layer out = fc_layer(out, hp.hidden_size, hp.dropout, name="fc") out.get_shape().assert_has_rank(3) out = tf.expand_dims(out, 2) return out
def flatten(inputs):
  return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
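# Common pattern in all of the bodies above: common_layers.flatten4d3d merges
# the two middle axes of a 4-D tensor so sequence models can treat the data as
# [batch, length, depth], and tf.expand_dims(..., axis=2) restores the dummy
# width axis that T2T model bodies are expected to return. A minimal sketch of
# that round trip (illustrative only; the real helper also preserves static
# shape information):
def _flatten4d3d_sketch(x):
  """[batch, height, width, depth] -> [batch, height * width, depth]."""
  shape = tf.shape(x)
  return tf.reshape(x, [shape[0], shape[1] * shape[2], shape[3]])


def _expand_back_sketch(x):
  """[batch, length, depth] -> [batch, length, 1, depth]."""
  return tf.expand_dims(x, axis=2)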