def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    encoder_outputs, final_encoder_state = lstm(
        tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_left(targets)
    decoder_outputs, _ = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs)
    return tf.expand_dims(decoder_outputs, axis=2)

def lstm_seq2seq_internal(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model, main step used for training."""
  with tf.variable_scope("lstm_seq2seq"):
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    _, final_encoder_state = lstm(
        tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    # LSTM decoder.
    shifted_targets = common_layers.shift_left(targets)
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)

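# Both seq2seq bodies above feed the decoder shift_left(targets) so that
# position t of the decoder input is the gold target at position t - 1
# (teacher forcing). A minimal sketch of the assumed behavior of
# common_layers.shift_left, not the library's actual implementation:
import tensorflow as tf


def shift_left_sketch(targets, pad_value=None):
  """Prepends one step of padding on the time axis and drops the last step."""
  # targets: [batch, length, 1, depth]; the output has the same shape.
  if pad_value is None:
    shifted = tf.pad(targets, [[0, 0], [1, 0], [0, 0], [0, 0]])
  else:
    shifted = tf.concat([pad_value, targets], axis=1)
  return shifted[:, :-1, :, :]
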
def slicenet_internal(inputs, targets, target_space, problem_idx, hparams):
  """The slicenet model, main step used for training."""
  with tf.variable_scope("slicenet"):
    # Flatten inputs and encode.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    inputs_mask = 1.0 - embedding_to_padding(inputs)
    inputs = common_layers.add_timing_signal(inputs)  # Add position info.
    target_space_emb = embed_target_space(target_space, hparams.hidden_size)
    extra_layers = int(hparams.num_hidden_layers * 1.5)
    inputs_encoded = multi_conv_res(
        inputs, "SAME", "encoder", extra_layers, hparams, mask=inputs_mask)
    target_modality_name = hparams.problems[problem_idx].target_modality.name
    if "class_label_modality" in target_modality_name:
      # If we're just predicting a class, there is no use for a decoder.
      return inputs_encoded
    # Do the middle part.
    decoder_start, similarity_loss = slicenet_middle(
        inputs_encoded, targets, target_space_emb, inputs_mask, hparams)
    # Decode.
    decoder_final = multi_conv_res(
        decoder_start, "LEFT", "decoder", hparams.num_hidden_layers, hparams,
        mask=inputs_mask, source=inputs_encoded)
    return decoder_final, tf.reduce_mean(similarity_loss)

def bytenet_internal(inputs, targets, hparams):
  """ByteNet, main step used for training."""
  with tf.variable_scope("bytenet"):
    # Flatten inputs and extend length by 50%.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    extend_length = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1]))
    inputs_shape = inputs.shape.as_list()
    inputs = tf.pad(inputs, [[0, 0], [0, extend_length], [0, 0], [0, 0]])
    inputs_shape[1] = None
    inputs.set_shape(inputs_shape)  # Don't lose the other shapes when padding.
    # Pad inputs and targets to be the same length, divisible by 50.
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=50)
    final_encoder = residual_dilated_conv(
        inputs, hparams.num_block_repeat, "SAME", "encoder", hparams)
    shifted_targets = common_layers.shift_left(targets)
    kernel = (hparams.kernel_height, hparams.kernel_width)
    decoder_start = common_layers.conv_block(
        tf.concat([final_encoder, shifted_targets], axis=3),
        hparams.hidden_size, [((1, 1), kernel)], padding="LEFT")
    return residual_dilated_conv(
        decoder_start, hparams.num_block_repeat, "LEFT", "decoder", hparams)

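# A minimal sketch of the assumed behavior of common_layers.pad_to_same_length
# as used above, simplified to the 4-d case (illustrative, not the library
# source): zero-pad both tensors on the length axis to a common length that is
# divisible by final_length_divisible_by.
import tensorflow as tf


def pad_to_same_length_sketch(x, y, final_length_divisible_by=1):
  x_len, y_len = tf.shape(x)[1], tf.shape(y)[1]
  divisor = final_length_divisible_by
  max_len = tf.maximum(x_len, y_len)
  max_len = tf.to_int32(tf.ceil(tf.to_float(max_len) / divisor)) * divisor
  x = tf.pad(x, [[0, 0], [0, max_len - x_len], [0, 0], [0, 0]])
  y = tf.pad(y, [[0, 0], [0, max_len - y_len], [0, 0], [0, 0]])
  return x, y
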
def model_fn_body(self, features):
  """Transformer body: encode inputs, then decode both target directions."""
  hparams = copy.copy(self._hparams)
  inputs = features.get("inputs")
  encoder_output, encoder_decoder_attention_bias = self.encode(inputs, hparams)

  targets_l2r = features["targets_l2r"]
  targets_r2l = features["targets_r2l"]
  targets_l2r = common_layers.flatten4d3d(targets_l2r)
  targets_r2l = common_layers.flatten4d3d(targets_r2l)

  (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder(
      targets_l2r, targets_r2l, hparams)

  decode_output = self.decode(decoder_input, encoder_output,
                              encoder_decoder_attention_bias,
                              decoder_self_attention_bias, hparams)
  return decode_output

def encode(self, inputs, hparams):
  """Encodes flattened inputs with the transformer encoder."""
  inputs = common_layers.flatten4d3d(inputs)
  (encoder_input, self_attention_bias,
   encoder_decoder_attention_bias) = transformer_prepare_encoder(
       inputs, hparams)
  encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.residual_dropout)
  encoder_output = transformer_encoder(encoder_input, self_attention_bias,
                                       hparams)
  return encoder_output, encoder_decoder_attention_bias

def targets_bottom(self, inputs):
  with tf.variable_scope(self.name):
    # Reshape inputs to a 3-d tensor and embed the RGB pixel values.
    inputs = common_layers.flatten4d3d(inputs)
    ret = common_layers.embedding(
        inputs, self.top_dimensionality, self._body_input_depth,
        name="input_rgb_embedding")
    if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
      ret *= self._body_input_depth**0.5
    return ret

def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
  """Middle part of slicenet, connecting encoder and decoder."""
  norm_fn = get_norm(hparams)

  # Flatten targets and embed target_space_id.
  targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
  target_space_emb = tf.tile(target_space_emb,
                             [tf.shape(targets_flat)[0], 1, 1, 1])

  # Calculate similarity loss (but don't run if not needed).
  if len(hparams.problems) > 1 and hparams.sim_loss_mult > 0.00001:
    targets_timed = common_layers.add_timing_signal(targets_flat)
    extra_layers = int(hparams.num_hidden_layers * 1.5)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      targets_encoded = multi_conv_res(targets_timed, "SAME", "encoder",
                                       extra_layers, hparams)
    with tf.variable_scope("similarity_loss"):
      similarity_loss = similarity_cost(inputs_encoded, targets_encoded)
      similarity_loss *= hparams.sim_loss_mult
  else:
    similarity_loss = 0.0

  # Use attention from each target to look at input and retrieve.
  targets_shifted = common_layers.shift_left(
      targets_flat, pad_value=target_space_emb)
  if hparams.attention_type == "none":
    targets_with_attention = tf.zeros_like(targets_shifted)
  else:
    inputs_padding_bias = (1.0 - mask) * -1e9  # Bias to not attend to padding.
    targets_with_attention = attention(
        targets_shifted, inputs_encoded, norm_fn, hparams,
        bias=inputs_padding_bias)

  # Positional targets: merge attention and raw.
  kernel = (hparams.kernel_height, hparams.kernel_width)
  targets_merged = common_layers.subseparable_conv_block(
      tf.concat([targets_with_attention, targets_shifted], axis=3),
      hparams.hidden_size, [((1, 1), kernel)],
      normalizer_fn=norm_fn, padding="LEFT", separability=4,
      name="targets_merge")

  return targets_merged, similarity_loss

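# The attention call above masks padding by adding a large negative bias to the
# attention logits before the softmax, which is what
# inputs_padding_bias = (1.0 - mask) * -1e9 implements. A small, self-contained
# illustration of that trick (names here are illustrative only):
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.5]])
mask = tf.constant([[1.0, 1.0, 0.0]])   # 1.0 for real tokens, 0.0 for padding.
bias = (1.0 - mask) * -1e9              # Same form as inputs_padding_bias.
weights = tf.nn.softmax(logits + bias)  # The padded position gets ~0 weight.
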
def preprocess_targets(targets, i):
  """Performs preprocessing steps on the targets to prepare for the decoder.

  Returns:
    Processed targets [batch_size, 1, hidden_dim]
  """
  # _shard_features called to ensure that the variable names match.
  targets = local_features["_shard_features"]({"targets": targets})["targets"]
  with tf.variable_scope(target_modality.name):
    targets = target_modality.targets_bottom_sharded(targets, dp)[0]
  targets = common_layers.flatten4d3d(targets)

  # TODO(llion): Explain! Is this even needed?
  targets = tf.cond(tf.equal(i, 0),
                    lambda: tf.zeros_like(targets),
                    lambda: targets)

  if hparams.pos == "timing":
    targets += timing_signal[:, i:i + 1]
  return targets

def flatten(inputs):
  """Flattens 4-d inputs to 3-d, then re-expands to 4-d with width 1."""
  return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)

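# Every snippet above relies on common_layers.flatten4d3d. A minimal sketch of
# its assumed behavior (illustrative, not the library source): merge the two
# spatial axes of a [batch, height, width, depth] tensor into one length axis,
# giving [batch, height * width, depth].
import tensorflow as tf


def flatten4d3d_sketch(x):
  shape = tf.shape(x)
  return tf.reshape(x, [shape[0], shape[1] * shape[2], shape[3]])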