def transformer_prepare_encoder(inputs, hparams):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: [batch_size, input_length, hidden_dim]
    hparams: hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
      [batch_size, input_length, hidden_dim]
    encoder_self_attention_bias: a bias tensor for use in encoder
      self-attention [batch_size, 1, 1, input_length]
    top_layer_attention_bias: a bias tensor for use in top layer
      classification [batch_size, 1, 1, input_length]
  """
  encoder_input = inputs
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  top_layer_attention_bias = ignore_padding
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(inputs)[1])
  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          top_layer_attention_bias)
def transformer_prepare_encoder(inputs, target_space, hparams): """Prepare one shard of the model for the encoder. Args: inputs: a Tensor. target_space: a Tensor. hparams: run hyperparameters Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ ishape_static = inputs.shape.as_list() encoder_input = inputs encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( tf.shape(inputs)[1]) # Append target_space_id embedding to inputs. emb_target_space = common_layers.embedding( target_space, 32, ishape_static[-1], name="target_space_embedding") emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space if hparams.pos == "timing": encoder_input = common_attention.add_timing_signal_1d(encoder_input) return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
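# --- Usage sketch (not from the original source) ---
# Minimal illustration of how the three tensors returned by
# transformer_prepare_encoder above are typically consumed.  Assumes TF 1.x
# and tensor2tensor are available; the hparams set and the toy shapes are
# placeholders, not values taken from this code base.
import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
inputs = tf.ones([2, 7, hparams.hidden_size])   # [batch, length, hidden_dim]
target_space = tf.constant(1, dtype=tf.int32)   # a SpaceID scalar

encoder_input, self_attn_bias, enc_dec_attn_bias = transformer_prepare_encoder(
    inputs, target_space, hparams)
encoder_input = tf.nn.dropout(
    encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
encoder_output = transformer.transformer_encoder(
    encoder_input, self_attn_bias, hparams)
# enc_dec_attn_bias is later passed to the decoder so that decoder-to-encoder
# attention ignores padded source positions.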
def prepare_image_question_encoder(image_feat, question, hparams):
  """Prepare encoder.

  Args:
    image_feat: a Tensor.
    question: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder
      self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  encoder_input = tf.concat([image_feat, question], axis=1)
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  encoder_decoder_attention_bias = ignore_padding
  # Usual case - not a packed dataset.
  if hparams.pos == "timing":
    question = common_attention.add_timing_signal_1d(question)
  elif hparams.pos == "emb":
    question = common_attention.add_positional_embedding(
        question, hparams.max_length, "inputs_positional_embedding", None)
  encoder_input = tf.concat([image_feat, question], axis=1)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
def transformer_prepare_encoder2(encoder_input, target_space, hparams,
                                 emb_name):
  """Same as transformer_prepare_encoder, except that the target-space
  embedding variable can be given a custom name via `emb_name`."""
  # compute bias
  ishape_static = encoder_input.shape.as_list()
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  encoder_decoder_attention_bias = ignore_padding
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(encoder_input)[1])
  # Append target_space_id embedding to encoder_input
  id_values = [
      value for attr, value in vars(problem.SpaceID).items()
      if not attr.startswith("__")
  ]
  id_cur = int(max(id_values) + 1)
  emb_target_space = common_layers.embedding(
      target_space, id_cur, ishape_static[-1], name=emb_name)
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space
  # position embedding
  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
def sample_p(self, targets_length, temp, check_invertibility=False, targets_mask=None, **kwargs): hparams = self._hparams if targets_mask is None: targets_mask = ops.sequence_mask(targets_length, hparams) decoder_self_attention_bias = ( common_attention.attention_bias_ignore_padding(1.0 - targets_mask)) batch_size, targets_max_length = ( common_layers.shape_list(targets_mask)[:2]) prior_shape = [batch_size, targets_max_length, hparams.latent_size] noise = tf.random.normal(prior_shape, stddev=temp) p_dist = None if hparams.prior_type == "standard_normal": z_p = noise elif hparams.prior_type == "diagonal_normal": diag_prior_params = ops.cond_prior("diag_prior", hparams, tf.zeros(prior_shape), targets_mask, hparams.latent_size * 2, decoder_self_attention_bias, **kwargs) p_dist = gops.diagonal_normal(diag_prior_params, "diag_prior") z_p = p_dist.loc + p_dist.scale * noise elif hparams.prior_type in ["affine", "additive", "rq"]: n_levels = len(hparams.depths.split("/")) divi = max(1, hparams.factor**(n_levels - 1)) flow_prior_shape = [ batch_size, targets_max_length // divi, hparams.latent_size ] noise = tf.random_normal(flow_prior_shape, stddev=temp) z_p, _, _, _ = glow.glow("glow", noise, targets_mask, decoder_self_attention_bias, inverse=True, init=False, hparams=self._fparams, disable_dropout=True, temp=temp, **kwargs) if self.is_evaluating and check_invertibility: noise_inv, _, _, _ = glow.glow("glow", z_p, targets_mask, decoder_self_attention_bias, inverse=False, init=False, hparams=self._fparams, disable_dropout=True, **kwargs) z_diff = noise - noise_inv tf.summary.scalar("flow_recon_inverse", tf.reduce_max(tf.abs(z_diff))) return z_p, p_dist
def get_attention_bias(sequence_length):
  """Create attention bias so attention is not applied at padding position."""
  # attention_bias: [batch, 1, 1, memory_length]
  invert_sequence_mask = tf.to_float(
      tf.logical_not(tf.sequence_mask(sequence_length)))
  attention_bias = common_attention.attention_bias_ignore_padding(
      invert_sequence_mask)
  return attention_bias
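# --- Usage sketch (not from the original source) ---
# Hypothetical call of get_attention_bias above.  Assumes TF 1.x and
# tensor2tensor's common_attention; the lengths and batch size are made up.
import tensorflow as tf

sequence_length = tf.constant([3, 5])      # two sequences, padded to length 5
attention_bias = get_attention_bias(sequence_length)
# attention_bias has shape [2, 1, 1, 5]; entries are 0.0 at real positions and
# a large negative value (-1e9) at padded positions, so softmax(logits + bias)
# puts near-zero probability mass on padding.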
def transformer_prepare_decoder_right(targets, hparams, features=None): """Prepare one shard of the model for the decoder. Args: targets: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. Returns: decoder_input: a Tensor, bottom of decoder stack decoder_self_attention_bias: a bias tensor for use in decoder self-attention """ if hparams.causal_decoder_self_attention: # Causal attention. if hparams.prepend_mode == "prepend_inputs_full_attention": decoder_self_attention_bias = ( common_attention.attention_bias_prepend_inputs_full_attention( common_attention.embedding_to_padding(targets))) else: decoder_self_attention_bias = ( common_attention.attention_bias_local( common_layers.shape_list(targets)[1], 0, -1)) else: # Full attention. decoder_padding = common_attention.embedding_to_padding(targets) decoder_self_attention_bias = ( common_attention.attention_bias_ignore_padding(decoder_padding)) if features and "targets_segmentation" in features: # "Packed" dataset - keep the examples from seeing each other. targets_segmentation = features["targets_segmentation"] targets_position = features["targets_position"] decoder_self_attention_bias += common_attention.attention_bias_same_segment( targets_segmentation, targets_segmentation) else: targets_position = None if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( common_layers.shape_list(targets)[1]) decoder_input = shift_left_3d(targets) if hparams.pos == "timing": if targets_position is not None: decoder_input = common_attention.add_timing_signal_1d_given_position( decoder_input, targets_position) else: decoder_input = common_attention.add_timing_signal_1d( decoder_input) elif hparams.pos == "emb": decoder_input = common_attention.add_positional_embedding( decoder_input, hparams.max_length, "targets_positional_embedding", targets_position) if hparams.activation_dtype == "bfloat16": decoder_self_attention_bias = tf.cast(decoder_self_attention_bias, tf.bfloat16) return (decoder_input, decoder_self_attention_bias)
def transformer_prepare_encoder(inputs, target_space, hparams, features=None): """Prepare one shard of the model for the encoder. Args: inputs: a Tensor. target_space: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ ishape_static = inputs.shape.as_list() encoder_input = inputs if features and "inputs_segmentation" in features: # Packed dataset. Keep the examples from seeing each other. inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] targets_segmentation = features["targets_segmentation"] encoder_self_attention_bias = common_attention.attention_bias_same_segment( inputs_segmentation, inputs_segmentation) encoder_decoder_attention_bias = ( common_attention.attention_bias_same_segment( targets_segmentation, inputs_segmentation)) else: # Usual case - not a packed dataset. encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( common_layers.shape_list(inputs)[1]) # Append target_space_id embedding to inputs. emb_target_space = common_layers.embedding(target_space, 32, ishape_static[-1], name="target_space_embedding") emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space if hparams.pos == "timing": if inputs_position is not None: encoder_input = common_attention.add_timing_signal_1d_given_position( encoder_input, inputs_position) else: encoder_input = common_attention.add_timing_signal_1d( encoder_input) return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
def get_ignore_padding(inputs):
  """Compute an attention bias that masks out padding positions.

  Args:
    inputs: Tensor with shape [batch, memory_length, depth]

  Returns:
    ignore_padding: bias Tensor with shape [batch, 1, 1, memory_length]
  """
  # Extract which individual embedding vectors are identically zero.
  # encoder_padding has shape [batch, memory_length].
  padding = comm_attn.embedding_to_padding(inputs)
  # ignore_padding has shape [batch, 1, 1, memory_length].  Padding positions
  # (the 1s in `padding`) become -1e9, so that once this bias is added to the
  # attention logits, softmax assigns them near-zero weight.
  ignore_padding = comm_attn.attention_bias_ignore_padding(padding)
  return ignore_padding
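# --- Illustration sketch (not from the original source) ---
# A hand-rolled equivalent of the two tensor2tensor calls used above, to make
# the masking mechanics explicit.  Assumes TF 1.x; this is a sketch of the
# behavior, not the library implementation itself.
import tensorflow as tf

def _embedding_to_padding_sketch(emb):
  # 1.0 where an embedding vector is entirely zero (i.e. a padded position).
  return tf.to_float(tf.equal(tf.reduce_sum(tf.abs(emb), axis=-1), 0.0))

def _attention_bias_ignore_padding_sketch(padding):
  # Broadcastable bias [batch, 1, 1, memory_length]: -1e9 at padded positions.
  return tf.expand_dims(tf.expand_dims(padding * -1e9, axis=1), axis=1)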
def transformer_prepare_encoder(inputs_emb_var, inputs, hparams, features=None): """Prepare one shard of the model for the encoder. Args: inputs_emb_var: a Tensor inputs: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ encoder_input = tf.gather(inputs_emb_var, inputs) if features and "inputs_segmentation" in features: # Packed dataset. Keep the examples from seeing each other. inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] targets_segmentation = features["targets_segmentation"] encoder_self_attention_bias = common_attention.attention_bias_same_segment( inputs_segmentation, inputs_segmentation) encoder_decoder_attention_bias = ( common_attention.attention_bias_same_segment( targets_segmentation, inputs_segmentation)) else: # Usual case - not a packed dataset. encoder_padding = tf.to_float(tf.equal(inputs, 0)) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): encoder_input = common_attention.add_positional_embedding( encoder_input, hparams.max_length, "positional_embedding", inputs_position) if hparams.activation_dtype == "bfloat16": encoder_self_attention_bias = tf.cast(encoder_self_attention_bias, tf.bfloat16) encoder_decoder_attention_bias = tf.cast( encoder_decoder_attention_bias, tf.bfloat16) return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
def model_fn_body(self, features): hparams = self._hparams inputs = features.get("inputs") firstP = features.get("firstP") firstP = common_layers.flatten4d3d(firstP) targets = features["targets"] targets = common_layers.flatten4d3d(targets) #JI: set image dimensions imageP = features.get("imageP") imageP.set_shape([None, 1, 19600]) imageP=tf.reshape(imageP,[-1, img_dim, 100]) encoder_output, encoder_decoder_attention_bias = (None, None) if inputs is not None: target_space = features["target_space_id"] #JI: if needed pass images to encoder encoder_output, encoder_decoder_attention_bias = self.encode(inputs, target_space, hparams, imageP=None) # used to extract hidden states (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder(firstP, hparams) # the conventional `targets` used for the second-pass decoder, i.e., delib-decoder (delibdecoder_input, delibdecoder_self_attention_bias) = transformer_prepare_decoder(targets, hparams) # the `delibctx` used for the second-pass decoder firstP_input, firstP_self_attention_bias = self.transformer_prepare_delibdecoder(firstP, hparams) # add dropout to the two decoders decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) delibdecoder_input = tf.nn.dropout(delibdecoder_input, 1.0 - hparams.layer_prepostprocess_dropout) decoder_output = transformer_decoder(decoder_input, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams, cache=None) firstP_input = tf.concat(values=[firstP_input, decoder_output], axis=-1) #JI: get biases for image attention img_encoder_padding = common_attention.embedding_to_padding(imageP) imageP_self_attention_bias = common_attention.attention_bias_ignore_padding(img_encoder_padding) #JI: pass images to the decoder delibdecoder_output = transformer_delibdecoder( delibdecoder_input, encoder_output, firstP_input, imageP, delibdecoder_self_attention_bias, encoder_decoder_attention_bias, firstP_self_attention_bias, imageP_self_attention_bias, hparams, cache=None, name="delib_decoder") return delibdecoder_output
def transformer_prepare_delibdecoder(self, inputs, hparams):
  """Prepare the first-pass decoder output for the deliberation decoder.

  Args:
    inputs: a Tensor.
    hparams: run hyperparameters

  Returns:
    firstPdecoder_input: a Tensor, the first-pass context fed to the
      deliberation decoder
    firstP_delib_attention_bias: a bias tensor that masks padding positions
  """
  firstPdecoder_input = inputs
  firstPdecoder_padding = common_attention.embedding_to_padding(
      firstPdecoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      firstPdecoder_padding)
  firstP_delib_attention_bias = ignore_padding
  if hparams.pos == "timing":
    firstPdecoder_input = common_attention.add_timing_signal_1d(
        firstPdecoder_input)
  return (firstPdecoder_input, firstP_delib_attention_bias)
def encode(self, inputs, hparams, features=None): train = hparams.mode == tf.estimator.ModeKeys.TRAIN inputs_length = common_layers.length_from_embedding(inputs) # Flatten inputs. inputs = common_layers.flatten4d3d(inputs) encoder_padding = common_attention.embedding_to_padding(inputs) encoder_decoder_attention_bias = common_attention.attention_bias_ignore_padding( encoder_padding) # LSTM encoder. encoder_outputs, final_encoder_state = lstm_bid_encoder( inputs, inputs_length, self._hparams, train, "encoder") return encoder_outputs, final_encoder_state, encoder_decoder_attention_bias, inputs_length
def compute_iw_marginal( self, targets, targets_mask, decoder_self_attention_bias, features, n_samples, reduce_mean=True, **kwargs): hparams = self._hparams z_q, log_q_z, _ = self.sample_q( targets, targets_mask, decoder_self_attention_bias, n_samples=n_samples, temp=1.0, **kwargs) # [K*B, L, C] iw_kwargs = {key: ops.prepare_for_iw(value, n_samples) for ( key, value) in kwargs.items()} iw_targets_mask = ops.prepare_for_iw(targets_mask, n_samples) iw_decoder_self_attention_bias = ( common_attention.attention_bias_ignore_padding(1.0 - iw_targets_mask)) iw_features = copy.copy(features) iw_features["targets"] = ops.prepare_for_iw( features["targets"], n_samples) log_p_z_base, log_abs_det = self.compute_prior_log_prob( z_q, iw_targets_mask, iw_decoder_self_attention_bias, check_invertibility=False, **iw_kwargs) log_p_z = log_p_z_base + log_abs_det body_output = ops.decoder( "decoder", z_q, hparams, iw_decoder_self_attention_bias, **iw_kwargs) logits = self.top(body_output, iw_features) numerator, denominator = self.loss_iw(logits, iw_features) numerator = tf.reduce_sum(numerator[..., 0, 0], 1) # [K*B] denominator = tf.reduce_sum(denominator[..., 0, 0], 1) # [K*B] log_p_x = -1 * numerator / denominator log_q_z = gops.reduce_mean_over_l_sum_over_c(log_q_z, iw_targets_mask) log_p_z = log_p_z / tf.reduce_sum(iw_targets_mask, 1) log_p_x, log_q_z, log_p_z = [ops.unprepare_for_iw(ii, n_samples) for ii in [ log_p_x, log_q_z, log_p_z]] log_w_n = log_p_z - log_q_z log_w_n = tf.nn.log_softmax(log_w_n, axis=0) # [K, B] iw_marginal = log_p_x + log_w_n iw_marginal = tf.reduce_logsumexp(iw_marginal, 0) # [B] if reduce_mean: iw_marginal = tf.cast(tf.reduce_mean(iw_marginal, 0), tf.float32) # [1] else: iw_marginal = tf.cast(iw_marginal, tf.float32) # [1] return iw_marginal
def model_fn_body(self, features): """Transformer main model_fn. Args: features: Map of features to the model. Should contain the following: "inputs": Transformer inputs [batch_size, input_length, hidden_dim] "tragets": Target decoder outputs. [batch_size, decoder_length, hidden_dim] "target_space_id" Returns: Final decoder representation. [batch_size, decoder_length, hidden_dim] """ #JI: set image shapes imageP = features.get("imageP") imageP.set_shape([None, 1, 19600]) imageP=tf.reshape(imageP,[-1, img_dim, 100]) hparams = self._hparams inputs = features.get("inputs") encoder_output, encoder_decoder_attention_bias = (None, None) if inputs is not None: target_space = features["target_space_id"] # JI: send images to encoder if needed encoder_output, encoder_decoder_attention_bias = self.encode( inputs, target_space, hparams,imageP=None) targets = features["targets"] targets = common_layers.flatten4d3d(targets) decoder_input, decoder_self_attention_bias = transformer_prepare_decoder( targets, hparams) #JI: compute attention bias for images for decoder img_encoder_padding = common_attention.embedding_to_padding(imageP) imageP_decoder_self_attention_bias = common_attention.attention_bias_ignore_padding(img_encoder_padding) #JI: send images for decoder if needed return self.decode(decoder_input, encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, imageP=imageP, imageP_decoder_self_attention_bias=imageP_decoder_self_attention_bias)
def forward(self, contexts_emb, contexts, abbr_inp_emb, longform_emb=None): """ :param contexts_emb: [batch_size, context_len, emb_dim] :param contexts: a list of tensors of words, [batch_size] * context_len :param abbr_inp_emb: [batch_size, 1, emb_dim] :param longform_emb: [batch_size, longform_len, emb_dim] :return: decoder_output: predicted abbr embedding, [batch_size, 1, emb_dim] """ saved_weights = {} extra_loss = None contexts_bias = common_attention.attention_bias_ignore_padding( tf.to_float( tf.equal(tf.stack(contexts, axis=1), self.voc.encode(constant.PAD)))) contexts_emb = tf.nn.dropout( contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout) abbr_inp_emb = tf.nn.dropout( abbr_inp_emb, 1.0 - self.hparams.layer_prepostprocess_dropout) # [batch_size, context_len, emb_dim] encoder_output = transformer.transformer_encoder( contexts_emb, contexts_bias, hparams=self.hparams, save_weights_to=saved_weights) # [batch_size, 1, emb_dim] decoder_output = transformer.transformer_decoder( abbr_inp_emb, encoder_output, decoder_self_attention_bias=tf.zeros( [self.model_config.batch_size, 1, 1, 1]), encoder_decoder_attention_bias=contexts_bias, hparams=self.hparams, save_weights_to=saved_weights) return decoder_output, saved_weights, extra_loss
def transformer_prepare_encoder(inputs, target_space, hparams): """Copied from tensor2tensor.models.transformer.""" ishape_static = inputs.shape.as_list() encoder_input = inputs encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( tf.shape(inputs)[1]) # Append target_space_id embedding to inputs. emb_target_space = common_layers.embedding(target_space, 32, ishape_static[-1], name="target_space_embedding") emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space if hparams.pos == "timing": encoder_input = common_attention.add_timing_signal_1d(encoder_input) return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
def prepare_question_encoder(inputs, hparams):
  """Prepare question encoder.

  Args:
    inputs: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder
      self-attention
  """
  encoder_input = inputs
  # Usual case - not a packed dataset.
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        None)
  return (encoder_input, encoder_self_attention_bias)
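# --- Usage sketch (not from the original source) ---
# Hypothetical call of prepare_question_encoder above under the two positional
# options it supports.  Assumes TF 1.x and tensor2tensor; the HParams object
# and the toy question tensor are placeholders.
import tensorflow as tf

hparams = tf.contrib.training.HParams(pos="timing", max_length=256)
question = tf.ones([2, 10, 512])   # [batch, question_length, hidden_dim]

enc_in, enc_bias = prepare_question_encoder(question, hparams)  # sinusoidal timing signal
hparams.set_hparam("pos", "emb")
enc_in_emb, _ = prepare_question_encoder(question, hparams)     # learned positional embedding
# In both cases enc_bias is [2, 1, 1, 10], with -1e9 at any position whose
# embedding row is all zeros (none in this toy input).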
def body(self, features): hparams = self._hparams ps_devices = self._ps_devices single_device = (len(ps_devices) == 1) assert hparams.num_model_shards % len(ps_devices) == 0 shards_per_device = hparams.num_model_shards // len(ps_devices) model_devices = [ps_devices[i // shards_per_device] for i in range(hparams.num_model_shards)] print("model_devices = %s" % model_devices) mp = expert_utils.Parallelism(model_devices, reuse=False) targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size # squeeze out channels, heights targets = tf.squeeze(features["targets_raw"], [2, 3]) targets_embedding_var = mp( tf.get_variable, "embedding", [[targets_vocab_size, hparams.hidden_size]] * mp.n, initializer=tf.random_normal_initializer( 0.0, hparams.hidden_size**-0.5)) shifted_targets = common_layers.shift_right_2d(targets) # Bypass the symbol modality and use a different embedding on each shard. if single_device: targets_embedding_var_combined = tf.concat(targets_embedding_var, 1) decoder_input_combined = common_layers.embedding( shifted_targets, targets_vocab_size, hparams.hidden_size * mp.n, multiplier=hparams.hidden_size**0.5, embedding_var=targets_embedding_var_combined, ) decoder_input = tf.split(decoder_input_combined, mp.n, axis=2) else: targets_embedding_var_combined = None decoder_input = mp( common_layers.embedding, shifted_targets, targets_vocab_size, hparams.hidden_size, multiplier=hparams.hidden_size**0.5, embedding_var=targets_embedding_var, ) decoder_self_attention_bias = mp( common_attention.attention_bias_lower_triangle, tf.shape(targets)[1]) if "targets_segmentation" in features: # "Packed" dataset - keep the examples from seeing each other. targets_segmentation = features["targets_segmentation"] targets_position = features["targets_position"] decoder_self_attention_bias = mp( tf.add, decoder_self_attention_bias, mp(common_attention.attention_bias_same_segment, targets_segmentation, targets_segmentation)) decoder_input = mp( common_attention.add_timing_signal_1d_given_position, decoder_input, targets_position) else: targets_position = None decoder_self_attention_bias = mp( common_attention.attention_bias_lower_triangle, tf.shape(targets)[1]) decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input) if self.has_input: inputs = tf.squeeze(features["inputs_raw"], [2, 3]) inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size # share everything for now share_inputs_and_targets_embedding = True if share_inputs_and_targets_embedding: assert inputs_vocab_size == targets_vocab_size inputs_embedding_var = targets_embedding_var inputs_embedding_var_combined = targets_embedding_var_combined if single_device: encoder_input_combined = common_layers.embedding( inputs, inputs_vocab_size, hparams.hidden_size * mp.n, multiplier=hparams.hidden_size**0.5, embedding_var=inputs_embedding_var_combined, ) encoder_input = tf.split(encoder_input_combined, mp.n, axis=2) else: encoder_input = mp( common_layers.embedding, inputs, inputs_vocab_size, hparams.hidden_size, multiplier=hparams.hidden_size**0.5, embedding_var=inputs_embedding_var, ) if "inputs_segmentation" in features: # "Packed" dataset - keep the examples from seeing each other. 
inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] encoder_self_attention_bias = mp( common_attention.attention_bias_same_segment, inputs_segmentation, inputs_segmentation) encoder_decoder_attention_bias = mp( common_attention.attention_bias_same_segment, targets_segmentation, inputs_segmentation) encoder_input = mp( common_attention.add_timing_signal_1d_given_position, encoder_input, inputs_position) else: encoder_padding = tf.to_float(tf.equal(inputs, 0)) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input) # encoder stack here with tf.variable_scope("encoder"): encoder_input = mp( tf.nn.dropout, encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) encoder_output = _layer_stack( mp, encoder_input, encoder_self_attention_bias, hparams.encoder_layers, hparams) else: encoder_decoder_attention_bias = None encoder_output = None with tf.variable_scope("decoder"): decoder_input = mp( tf.nn.dropout, decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) decoder_output = _layer_stack( mp, decoder_input, decoder_self_attention_bias, layers=hparams.decoder_layers, hparams=hparams, encoder_output=encoder_output, encoder_decoder_attention_bias=encoder_decoder_attention_bias) # Bypass the symbol modality and compute logits directly. # We compute a different set of logits on each shard, and sum them. # Share the weights with the target embedding. output_var = targets_embedding_var output_var_combined = targets_embedding_var_combined if single_device: decoder_output = tf.concat(decoder_output, 2) logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]]) num, denom = common_layers.padded_cross_entropy( logits, targets, hparams.label_smoothing) training_loss = num / denom else: logits = mp( tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n) logits = expert_utils.all_reduce_ring(logits, mp) # On each device, we compute the loss for a part of the batch. # This is faster than computing the whole loss on one shard. mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0]) def _loss_for_shard(logits, targets, shard): logits = common_layers.approximate_split(logits, mp.n, 0)[shard] targets = common_layers.approximate_split(targets, mp.n, 0)[shard] return common_layers.padded_cross_entropy( logits, targets, hparams.label_smoothing) num, denom = mp(_loss_for_shard, logits, targets, range(mp.n)) training_loss = tf.add_n(num) / tf.add_n(denom) logits = logits[0] logits = tf.expand_dims(tf.expand_dims(logits, 2), 3) # override training loss so that it is not computed externally. losses = {"training": training_loss} return logits, losses
def transformer_fn(self, sentence_complex_input_placeholder, emb_complex, sentence_simple_input_placeholder, emb_simple, w, b, rule_id_input_placeholder, rule_target_input_placeholder, mem_contexts, mem_outputs, global_step, score, comp_features, obj): encoder_mask = tf.to_float( tf.equal(tf.stack(sentence_complex_input_placeholder, axis=1), self.data.vocab_complex.encode(constant.SYMBOL_PAD))) encoder_attn_bias = common_attention.attention_bias_ignore_padding(encoder_mask) obj_tensors = {} train_mode = self.model_config.train_mode if self.model_config.bert_mode: # Leave space for decoder when static seq gpu_id = 0 if train_mode == 'static_seq' or train_mode == 'static_self-critical' or 'direct' in self.model_config.memory else 1 with tf.device('/device:GPU:%s' % gpu_id): sentence_complex_input = tf.stack(sentence_complex_input_placeholder, axis=1) bert_model = BertModel( BertConfig.from_json_file(self.model_config.bert_config), self.is_train, sentence_complex_input, input_mask=1.0-encoder_mask, token_type_ids=None, use_one_hot_embeddings=False) encoder_embed_inputs = bert_model.embedding_output encoder_outputs = bert_model.sequence_output emb_complex = bert_model.embedding_table # update emb complex if (self.model_config.tie_embedding == 'all' or self.model_config.tie_embedding == 'enc_dec'): emb_simple = bert_model.embedding_table if (self.model_config.tie_embedding == 'all' or self.model_config.tie_embedding == 'dec_out'): emb_w_proj = tf.get_variable( 'emb_w_proj', shape=[self.model_config.dimension, self.model_config.dimension], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32) w = tf.matmul(bert_model.embedding_table, emb_w_proj) if 'direct' in self.model_config.memory: with tf.device('/device:GPU:1'): direct_mask = tf.to_float( tf.equal(tf.stack(rule_target_input_placeholder, axis=1), self.data.vocab_complex.encode(constant.SYMBOL_PAD))) direct_bert_model = BertModel( BertConfig.from_json_file(self.model_config.bert_config), self.is_train, tf.stack(rule_target_input_placeholder, axis=1), input_mask=1.0 - direct_mask, token_type_ids=None, use_one_hot_embeddings=False, embedding_table=emb_simple, scope='direct') direct_bert_output = direct_bert_model.sequence_output obj_tensors['direct_bert_bias'] = common_attention.attention_bias_ignore_padding(direct_mask) obj_tensors['direct_bert_output'] = direct_bert_output else: encoder_embed_inputs = tf.stack( self.embedding_fn(sentence_complex_input_placeholder, emb_complex), axis=1) if self.hparams.pos == 'timing': encoder_embed_inputs = common_attention.add_timing_signal_1d(encoder_embed_inputs) print('Use positional encoding in encoder text.') if self.model_config.subword_vocab_size and self.model_config.seg_mode: encoder_embed_inputs = common_attention.add_positional_embedding( encoder_embed_inputs, 100, 'seg_embedding', positions=obj['line_comp_segids']) print('Add segment embedding.') with tf.variable_scope('transformer_encoder'): encoder_embed_inputs = tf.nn.dropout(encoder_embed_inputs, 1.0 - self.hparams.layer_prepostprocess_dropout) if self.model_config.architecture == 'ut2t': encoder_outputs, encoder_extra_output = universal_transformer_util.universal_transformer_encoder( encoder_embed_inputs, encoder_attn_bias, self.hparams) enc_ponder_times, enc_remainders = encoder_extra_output extra_encoder_loss = ( self.hparams.act_loss_weight * tf.reduce_mean(enc_ponder_times + enc_remainders)) if self.is_train: obj_tensors['extra_encoder_loss'] = extra_encoder_loss else: encoder_outputs = transformer.transformer_encoder( 
encoder_embed_inputs, encoder_attn_bias, self.hparams) # Update score based on multiplier score, pred_score_tuple = self.update_score( score, encoder_outputs=encoder_outputs, encoder_mask=tf.to_float( tf.not_equal(tf.stack(sentence_complex_input_placeholder, axis=1), self.data.vocab_complex.encode(constant.SYMBOL_PAD))), comp_features=comp_features) encoder_outputs = self.update_encoder_embedding(encoder_outputs, score) encoder_embed_inputs_list = tf.unstack(encoder_embed_inputs, axis=1) with tf.variable_scope('transformer_decoder', reuse=tf.AUTO_REUSE): if self.model_config.subword_vocab_size or 'bert_token' in self.model_config.bert_mode: go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)[0] else: go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO) batch_go = tf.tile( tf.expand_dims(self.embedding_fn(go_id, emb_simple), axis=0), [self.model_config.batch_size, 1]) # For static_seq train_mode if self.model_config.npad_mode == 'static_seq': with tf.variable_scope('npad'): npad_w = tf.get_variable( 'npad_w', shape=[1, self.model_config.dimension, self.model_config.dimension], initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32) obj_tensors['npad_w'] = npad_w if self.is_train and (train_mode == 'teacher' or train_mode == 'teachercritical'or train_mode == 'teachercriticalv2'): # General train print('Use Generally Process.') decoder_embed_inputs_list = self.embedding_fn( sentence_simple_input_placeholder[:-1], emb_simple) final_output, decoder_output, cur_context = self.decode_step( decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias, rule_id_input_placeholder, mem_contexts, mem_outputs, global_step, score, batch_go, obj_tensors) decoder_logit = ( tf.nn.conv1d(final_output, tf.expand_dims(tf.transpose(w), axis=0), 1, 'SAME') + tf.expand_dims(tf.expand_dims(b, axis=0), axis=0)) decoder_target_list = [] decoder_logit_list = tf.unstack(decoder_logit, axis=1) for logit in decoder_logit_list: decoder_target_list.append(tf.argmax(logit, output_type=tf.int32, axis=-1)) decoder_output_list = [ tf.squeeze(d, 1) for d in tf.split(decoder_output, self.model_config.max_simple_sentence, axis=1)] final_output_list = [ tf.squeeze(d, 1) for d in tf.split(final_output, self.model_config.max_simple_sentence, axis=1)] if self.model_config.pointer_mode: segment_mask = None if 'line_comp_segids' in obj: segment_mask = obj['line_comp_segids'] decoder_logit_list = word_distribution( decoder_logit_list, decoder_output_list, encoder_outputs, encoder_embed_inputs, sentence_complex_input_placeholder, obj_tensors, self.model_config, self.data, segment_mask) elif self.is_train and (train_mode == 'static_seq' or train_mode == 'static_self-critical'): decoder_target_list = [] decoder_logit_list = [] decoder_embed_inputs_list = [] # Will Override for following 3 lists final_output_list = [] decoder_output_list = [] contexts = [] sample_target_list = [] sample_logit_list = [] gpu_assign_interval = int(self.model_config.max_simple_sentence / 3) for step in range(self.model_config.max_simple_sentence): gpu_id = int(step / gpu_assign_interval) if gpu_id > 3: gpu_id = 3 gpu_id += 1 with tf.device('/device:GPU:%s' % gpu_id): print('Step%s with GPU%s' % (step, gpu_id)) final_outputs, _, cur_context = self.decode_step( decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias, rule_id_input_placeholder, mem_contexts, mem_outputs, global_step, score, batch_go, obj_tensors) final_output_list = [ tf.squeeze(d, 1) for d in tf.split(final_outputs, step+1, axis=1)] final_output = 
final_output_list[-1] # if self.model_config.npad_mode == 'static_seq': # final_output = tf.matmul(final_output, npad_w) last_logit_list = self.output_to_logit(final_output, w, b) last_target_list = tf.argmax(last_logit_list, output_type=tf.int32, axis=-1) decoder_logit_list.append(last_logit_list) decoder_target_list.append(last_target_list) decoder_embed_inputs_list.append( tf.stop_gradient(self.embedding_fn(last_target_list, emb_simple))) if train_mode == 'static_self-critical': last_sample_list = tf.multinomial(last_logit_list, 1) sample_target_list.append(last_sample_list) indices = tf.stack( [tf.range(0, self.model_config.batch_size, dtype=tf.int64), tf.squeeze(last_sample_list)], axis=-1) sample_logit_list.append(tf.gather_nd(tf.nn.softmax(last_logit_list), indices)) else: # Beam Search print('Use Beam Search with Beam Search Size %d.' % self.model_config.beam_search_size) return self.transformer_beam_search(encoder_outputs, encoder_attn_bias, encoder_embed_inputs_list, sentence_complex_input_placeholder, emb_simple, w, b, rule_id_input_placeholder, mem_contexts, mem_outputs, global_step, score, obj, obj_tensors) gt_target_list = sentence_simple_input_placeholder output = ModelOutput( contexts=cur_context if 'rule' in self.model_config.memory else None, encoder_outputs=encoder_outputs, decoder_outputs_list=final_output_list if train_mode != 'dynamic_self-critical' else None, final_outputs_list=final_output_list if train_mode != 'dynamic_self-critical' else None, decoder_logit_list=decoder_logit_list if train_mode != 'dynamic_self-critical' else None, gt_target_list=gt_target_list, encoder_embed_inputs_list=tf.unstack(encoder_embed_inputs, axis=1), decoder_target_list=decoder_target_list, sample_logit_list=sampled_logit_list if train_mode == 'dynamic_self-critical' else None, sample_target_list=sampled_target_list if train_mode == 'dynamic_self-critical' else None, pred_score_tuple=pred_score_tuple if 'pred' in self.model_config.tune_mode else None, obj_tensors=obj_tensors, ) return output
def transformer_prepare_encoder(inputs, target_space, hparams, features=None, type_ids=None, num_types=None, reuse_target_embedding=tf.AUTO_REUSE): """Prepare one shard of the model for the encoder. Args: inputs: a Tensor. target_space: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. type_ids: optional, an int64 Tensor of shape [batch, length] that allows for adding type embeddings, similar to positional embeddings. num_types: optional, an int that decides the number of types in type_ids. reuse_target_embedding: option to reuse variable name in the case that symbol modalities are reused between inputs/targets. Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ ishape_static = inputs.shape.as_list() encoder_input = inputs if features and "inputs_segmentation" in features: # Packed dataset. Keep the examples from seeing each other. inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] targets_segmentation = features["targets_segmentation"] if (hasattr(hparams, "unidirectional_encoder") and hparams.unidirectional_encoder): tf.logging.info("Using unidirectional encoder") encoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle( common_layers.shape_list(inputs)[1])) else: encoder_self_attention_bias = ( common_attention.attention_bias_same_segment( inputs_segmentation, inputs_segmentation)) encoder_decoder_attention_bias = ( common_attention.attention_bias_same_segment( targets_segmentation, inputs_segmentation)) else: encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) if (hasattr(hparams, "unidirectional_encoder") and hparams.unidirectional_encoder): tf.logging.info("Using unidirectional encoder") encoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle( common_layers.shape_list(inputs)[1])) else: # Usual case - not a packed dataset. encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( common_layers.shape_list(inputs)[1]) if target_space is not None and hparams.get("use_target_space_embedding", True): # Append target_space_id embedding to inputs. 
emb_target_space = common_layers.embedding( target_space, 32, ishape_static[-1], name="target_space_embedding", dtype=hparams.get("activation_dtype", "float32"), reuse=reuse_target_embedding) emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space if hparams.pos == "timing": if inputs_position is not None: encoder_input = common_attention.add_timing_signal_1d_given_position( encoder_input, inputs_position) else: encoder_input = common_attention.add_timing_signal_1d( encoder_input) elif hparams.pos == "timing_from_features": encoder_input = common_attention.add_timing_signals_from_features( encoder_input, features, hparams.position_features) elif hparams.pos == "emb": encoder_input = common_attention.add_positional_embedding( encoder_input, hparams.max_length, "inputs_positional_embedding", inputs_position) # Add type embeddings if type_ids is not None: if not num_types: raise ValueError("Need to set num_types as well.") encoder_input = common_attention.add_positional_embedding( encoder_input, num_types, "inputs_type_embedding", type_ids) encoder_self_attention_bias = common_layers.cast_like( encoder_self_attention_bias, encoder_input) encoder_decoder_attention_bias = common_layers.cast_like( encoder_decoder_attention_bias, encoder_input) return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
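# --- Usage sketch (not from the original source) ---
# Hypothetical call of the type-aware transformer_prepare_encoder above, e.g.
# to distinguish two segments of an input with type embeddings.  Assumes
# TF 1.x and tensor2tensor; the hparams, shapes and type ids are placeholders.
import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_base()
inputs = tf.ones([2, 6, hparams.hidden_size])
type_ids = tf.constant([[0, 0, 0, 1, 1, 1]] * 2, dtype=tf.int64)

enc_in, self_bias, enc_dec_bias = transformer_prepare_encoder(
    inputs, target_space=None, hparams=hparams,
    type_ids=type_ids, num_types=2)
# With target_space=None the target-space embedding is skipped; the type
# embedding table "inputs_type_embedding" (num_types rows) is added on top of
# the positional signal, and both biases mask padded positions as usual.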
def create_model(self): with tf.variable_scope('variables'): abstr_ph = [] for _ in range(self.model_config.max_abstr_len): abstr_ph.append(tf.zeros(self.model_config.batch_size, tf.int32, name='abstract_input')) kwords_ph = [] for _ in range(self.model_config.max_cnt_kword): kword = [] for _ in range(self.model_config.max_kword_len): kword.append(tf.zeros(self.model_config.batch_size, tf.int32, name='kword_input')) kwords_ph.append(kword) # Train for length control if self.is_train: kword_occupies_ph = [] for _ in range(self.model_config.max_cnt_kword): kword_occupies_ph.append( tf.zeros(self.model_config.batch_size, tf.float32, name='kword_occupy_input')) emb_abstr, emb_kword, proj_w, proj_b = self.get_embedding() abstr = tf.stack(self.embedding_fn(abstr_ph, emb_abstr), axis=1) kwords = [] for kword_idx in range(self.model_config.max_cnt_kword): kwords.append(self.embedding_fn(kwords_ph[kword_idx], emb_kword)) with tf.variable_scope('model_encoder'): if self.hparams.pos == 'timing': abstr = common_attention.add_timing_signal_1d(abstr) encoder_embed_inputs = tf.nn.dropout(abstr, 1.0 - self.hparams.layer_prepostprocess_dropout) abstr_bias = common_attention.attention_bias_ignore_padding( tf.to_float(tf.equal(tf.stack(abstr_ph, axis=1), self.voc_kword.encode(constant.SYMBOL_PAD)))) abstr_outputs = transformer.transformer_encoder( encoder_embed_inputs, abstr_bias, self.hparams) losses = [] targets = [] pred_occupies = [] obj = {} hist_vector = None if 'kp_attn' in self.model_config.cov_mode: hist_vector = tf.zeros( [self.model_config.batch_size, 1, self.model_config.dimension,]) with tf.variable_scope('model_decoder'): if self.model_config.subword_vocab_size: go_id = self.voc_kword.encode(constant.SYMBOL_GO)[0] else: go_id = self.voc_kword.encode(constant.SYMBOL_GO) batch_go = tf.tile( tf.expand_dims(self.embedding_fn(go_id, emb_kword), axis=0), [self.model_config.batch_size, 1]) for kword_idx in range(self.model_config.max_cnt_kword): if self.is_train: kword = kwords[kword_idx][:-1] kword_ph = kwords_ph[kword_idx] kword_output, kword_output_list = self.decode_step( kword, abstr_outputs, abstr_bias, batch_go, hist_vector=hist_vector) kword_logit_list = [self.output_to_logit(o, proj_w, proj_b) for o in kword_output_list] kword_target_list = [tf.argmax(o, output_type=tf.int32, axis=-1) for o in kword_logit_list] kword_lossbias = [ tf.to_float(tf.not_equal(d, self.voc_kword.encode(constant.SYMBOL_PAD))) for d in kword_ph] kword_lossbias = tf.stack(kword_lossbias, axis=1) if self.model_config.number_samples > 0: loss_fn = tf.nn.sampled_softmax_loss else: loss_fn = None loss = sequence_loss(logits=tf.stack(kword_logit_list, axis=1), targets=tf.stack(kword_ph, axis=1), weights=kword_lossbias, softmax_loss_function=loss_fn, w=proj_w, b=proj_b, decoder_outputs=tf.stack(kword_output_list, axis=1), number_samples=self.model_config.number_samples ) kword_target = tf.stack(kword_target_list, axis=1) targets.append(kword_target) if 'kp_attn' in self.model_config.cov_mode: kword_embed = self.embedding_fn(kword_ph, emb_kword) hist_vector += tf.expand_dims(tf.reduce_mean( tf.stack(kword_embed, axis=1), axis=1), axis=1) # Train for length control pred_occupy = self.get_pred_occupy_logit(hist_vector, abstr_outputs) occupy_loss = tf.nn.sigmoid_cross_entropy_with_logits( logits=pred_occupy, labels=kword_occupies_ph[kword_idx]) loss += tf.reduce_mean(occupy_loss) pred_occupies.append(pred_occupy) losses.append(loss) else: loss, kword_target = self.transformer_beam_search( abstr_outputs, abstr_bias, emb_kword, 
proj_w, proj_b, hist_vector=hist_vector) targets.append(kword_target) losses = loss if 'kp_attn' in self.model_config.cov_mode: kword_embed = self.embedding_fn(kword_target, emb_kword) hist_vector += tf.expand_dims(tf.reduce_mean(kword_embed, axis=1), axis=1) pred_occupy = tf.round(tf.sigmoid(self.get_pred_occupy_logit(hist_vector, abstr_outputs))) pred_occupies.append(pred_occupy) tf.get_variable_scope().reuse_variables() if targets: obj['targets'] = tf.stack(targets, axis=1) obj['abstr_ph'] = abstr_ph obj['kwords_ph'] = kwords_ph if self.is_train: obj['kword_occupies_ph'] = kword_occupies_ph pred_occupies = tf.stack(pred_occupies, axis=1) obj['pred_occupies'] = pred_occupies if type(losses) is list: losses = tf.add_n(losses) return losses, obj
def hierarchical_attention_network_encoder( encoder_input, encoder_self_attention_bias, contexts, context_self_attention_biases, features, hparams, name="hierarchical_attention_network_encoder", save_weights_to=None, make_image_summary=True, losses=None): input_x = encoder_input context_xs = {} for context_name in contexts: context_xs[context_name] = contexts[context_name] context_paddings = {} context_nonpaddings = {} context_pad_removers = {} attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) with tf.variable_scope(name, reuse=tf.AUTO_REUSE): input_padding = common_attention.attention_bias_to_padding( encoder_self_attention_bias) input_nonpadding = 1.0 - input_padding for context_name in context_self_attention_biases: context_paddings[ context_name] = common_attention.attention_bias_to_padding( context_self_attention_biases[context_name]) context_nonpaddings[ context_name] = 1.0 - context_paddings[context_name] input_pad_remover = None for context_name in context_paddings: context_pad_removers[context_name] = None if hparams.use_pad_remover and not common_layers.is_xla_compiled(): input_pad_remover = expert_utils.PadRemover(input_padding) for context_name in context_paddings: context_pad_removers[context_name] = expert_utils.PadRemover( context_paddings[context_name]) temp_hparam = tf.contrib.training.HParams( ) # copy hparams except num_hidden_layers -> num_hidden_layers - 1 for key, val in hparams.values().items(): temp_hparam.add_hparam(key, val) temp_hparam.set_hparam("num_hidden_layers", hparams.num_hidden_layers - 1) encoder_output = transformer_with_contexts_layers.transformer_encoder( input_x, encoder_self_attention_bias, temp_hparam, nonpadding=features_to_nonpadding(features, "inputs"), save_weights_to=save_weights_to, make_image_summary=make_image_summary) context_encoded_outputs = {} for context_name in context_xs: context_encoded_outputs[ context_name] = transformer_with_contexts_layers.transformer_encoder( context_xs[context_name], context_self_attention_biases[context_name], hparams, nonpadding=features_to_nonpadding(features, context_name), save_weights_to=save_weights_to, make_image_summary=make_image_summary) with tf.variable_scope('word_abstraction', reuse=tf.AUTO_REUSE): encoder_word_level_query = common_layers.dense( encoder_output, hparams.hidden_size) # q_w = f_w(h_t) encoder_word_level_abstraction = {} for context_name in context_encoded_outputs: encoder_word_level_abstraction[ context_name] = transformer_with_contexts_layers.multihead_attention( common_layers.layer_preprocess( encoder_word_level_query, hparams), context_encoded_outputs[context_name], context_self_attention_biases[context_name], hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, make_image_summary=make_image_summary, max_relative_position=hparams.max_relative_position, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d")) # s^j, sentence_information = tf.concat([ encoder_word_level_abstraction[context_name] for context_name in encoder_word_level_abstraction ], axis=1) with tf.variable_scope('sentence_abstraction', reuse=tf.AUTO_REUSE): encoder_sentence_level_query = common_layers.dense( encoder_output, 
hparams.hidden_size) # q_s = f_s(h_t) context_padding = common_attention.embedding_to_padding( sentence_information) ignore_padding = common_attention.attention_bias_ignore_padding( context_padding) contextual_information = transformer_with_contexts_layers.multihead_attention( common_layers.layer_preprocess(encoder_sentence_level_query, hparams), sentence_information, ignore_padding, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, make_image_summary=make_image_summary, max_relative_position=hparams.max_relative_position, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d") ) # MultiHead(q_s, s^j), [batch, encoder_length, hidden_dim] contextual_information = common_layers.dense_relu_dense( contextual_information, hparams.filter_size, hparams.hidden_size) with tf.variable_scope('context_gating', reuse=tf.AUTO_REUSE): gate_lambda = tf.nn.sigmoid( common_layers.dense(contextual_information, hparams.hidden_size) + common_layers.dense(encoder_output, hparams.hidden_size)) encoder_output = gate_lambda * encoder_output + ( 1 - gate_lambda) * contextual_information return common_layers.layer_preprocess(encoder_output, hparams)
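As a rough sketch of the context gating at the end of hierarchical_attention_network_encoder: a sigmoid gate computed from both the encoder states and the contextual summary interpolates between the two. The helper below is illustrative; its names are not from the original code.

import tensorflow as tf

def context_gate(encoder_states, contextual_info, hidden_size):
    # lambda = sigmoid(W_h * h + W_c * c); output = lambda * h + (1 - lambda) * c
    gate = tf.nn.sigmoid(
        tf.layers.dense(encoder_states, hidden_size, name="gate_from_encoder") +
        tf.layers.dense(contextual_info, hidden_size, name="gate_from_context"))
    return gate * encoder_states + (1.0 - gate) * contextual_info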
def hierarchical_context_encoder(encoder_input, encoder_self_attention_bias, contexts, context_self_attention_biases, features, hparams, name="discourse_aware_encoder", save_weights_to=None, make_image_summary=True, losses=None): input_x = encoder_input context_xs = {} for context_name in contexts: context_xs[context_name] = contexts[context_name] context_paddings = {} context_nonpaddings = {} context_pad_removers = {} attention_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "attention_dropout_broadcast_dims", ""))) with tf.variable_scope(name, reuse=tf.AUTO_REUSE): input_padding = common_attention.attention_bias_to_padding( encoder_self_attention_bias) input_nonpadding = 1.0 - input_padding for context_name in context_self_attention_biases: context_paddings[ context_name] = common_attention.attention_bias_to_padding( context_self_attention_biases[context_name]) context_nonpaddings[ context_name] = 1.0 - context_paddings[context_name] input_pad_remover = None for context_name in context_paddings: context_pad_removers[context_name] = None if hparams.use_pad_remover and not common_layers.is_xla_compiled(): input_pad_remover = expert_utils.PadRemover(input_padding) for context_name in context_paddings: context_pad_removers[context_name] = expert_utils.PadRemover( context_paddings[context_name]) temp_hparam = tf.contrib.training.HParams( ) # copy hparams except num_hidden_layers -> num_hidden_layers - 1 for key, val in hparams.values().items(): temp_hparam.add_hparam(key, val) temp_hparam.set_hparam("num_hidden_layers", hparams.num_hidden_layers - 1) encoder_output = transformer_with_contexts_layers.transformer_encoder( input_x, encoder_self_attention_bias, temp_hparam, nonpadding=features_to_nonpadding(features, "inputs"), save_weights_to=save_weights_to, make_image_summary=make_image_summary) context_encoded_outputs = {} for context_name in context_xs: context_encoded_outputs[ context_name] = transformer_with_contexts_layers.transformer_encoder( context_xs[context_name], context_self_attention_biases[context_name], temp_hparam, nonpadding=features_to_nonpadding(features, context_name), save_weights_to=save_weights_to, make_image_summary=make_image_summary) with tf.variable_scope("hierarchical_context_encoder", reuse=tf.AUTO_REUSE): for context_name in context_encoded_outputs: # self attention feed-forward _y = ffn_self_attention_layer( context_encoded_outputs[context_name], hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, save_weights_to=save_weights_to, name="attentive_sum") # mean over sequence length context_encoded_outputs[context_name] = tf.reduce_mean( _y, axis=1, keep_dims=True) encoded_contexts = [ context_encoded_outputs[context_name] for context_name in context_encoded_outputs ] encoded_contexts = tf.concat(encoded_contexts, axis=1) temp_hparam = tf.contrib.training.HParams( ) # copy hparams except num_hidden_layers -> 1 for key, val in hparams.values().items(): temp_hparam.add_hparam(key, val) temp_hparam.set_hparam("num_hidden_layers", 1) context_padding = common_attention.embedding_to_padding( encoded_contexts) ignore_padding = common_attention.attention_bias_ignore_padding( context_padding) encoded_contexts = transformer_encoder(encoded_contexts, ignore_padding, temp_hparam) with tf.variable_scope("encoder/layer_%d" % hparams.num_hidden_layers, reuse=tf.AUTO_REUSE): with tf.variable_scope("context_input_attention"): context_padding = common_attention.embedding_to_padding( encoded_contexts) 
ignore_padding = common_attention.attention_bias_ignore_padding( context_padding) _y = common_attention.multihead_attention( common_layers.layer_preprocess(encoder_output, hparams), encoded_contexts, ignore_padding, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, make_image_summary=make_image_summary, max_relative_position=hparams.max_relative_position, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d")) encoded_contexts = common_layers.layer_postprocess( encoder_output, _y, hparams) with tf.variable_scope("input_self_attention"): _y = common_attention.multihead_attention( common_layers.layer_preprocess(encoder_output, hparams), None, encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, save_weights_to=save_weights_to, max_relative_position=hparams.max_relative_position, make_image_summary=make_image_summary, dropout_broadcast_dims=attention_dropout_broadcast_dims, max_length=hparams.get("max_length"), vars_3d=hparams.get("attention_variables_3d")) encoder_output = common_layers.layer_postprocess( encoder_output, _y, hparams) with tf.variable_scope("gated_sum"): _depth = common_layers.shape_list(encoder_output)[-1] gate = tf.layers.dense(tf.concat( [encoded_contexts, encoder_output], axis=-1), _depth, activation=tf.nn.sigmoid) if save_weights_to: save_weights_to["gated_sum"] = gate encoder_output = gate * encoder_output + ( 1. - gate) * encoded_contexts with tf.variable_scope("ffn"): _y = transformer_ffn_layer(common_layers.layer_preprocess( encoder_output, hparams), hparams, input_pad_remover, conv_padding="SAME", nonpadding_mask=input_nonpadding, losses=losses) encoder_output = common_layers.layer_postprocess( encoder_output, _y, hparams) return common_layers.layer_preprocess(encoder_output, hparams)
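A hedged sketch of the context pooling step in hierarchical_context_encoder: each encoded context is collapsed to a single vector and the vectors are stacked along the length axis so that a one-layer encoder can attend over contexts as if they were tokens. The helper is illustrative only.

import tensorflow as tf

def pool_contexts(context_encodings):
    # context_encodings: dict name -> [batch, length, hidden] (assumed shapes).
    pooled = [tf.reduce_mean(enc, axis=1, keepdims=True)   # [batch, 1, hidden]
              for enc in context_encodings.values()]
    return tf.concat(pooled, axis=1)                        # [batch, num_contexts, hidden]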
def transformer_prepare_encoder(inputs, target_space, hparams, features=None): """Prepare one shard of the model for the encoder. Args: inputs: a Tensor. target_space: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ ishape_static = inputs.shape.as_list() encoder_input = inputs if features and "inputs_segmentation" in features: # Packed dataset. Keep the examples from seeing each other. inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] targets_segmentation = features["targets_segmentation"] encoder_self_attention_bias = common_attention.attention_bias_same_segment( inputs_segmentation, inputs_segmentation) encoder_decoder_attention_bias = ( common_attention.attention_bias_same_segment( targets_segmentation, inputs_segmentation)) else: # Usual case - not a packed dataset. encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( common_layers.shape_list(inputs)[1]) # Append target_space_id embedding to inputs. emb_target_space = common_layers.embedding(target_space, 32, ishape_static[-1], name="target_space_embedding") emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space #if hparams.pos == "timing": # if inputs_position is not None: # encoder_input = common_attention.add_timing_signal_1d_given_position( # encoder_input, inputs_position) # else: # encoder_input = common_attention.add_timing_signal_1d(encoder_input) raw_encoder_input = tf.squeeze(features['inputs_raw'], axis=[-2, -1]) pos_signals = generate_positional_signals(raw_encoder_input, hparams) pos_embeddings = generate_positional_embeddings(pos_signals, hparams.encoder_pos, hparams) if "sum" in hparams.encoder_pos_integration: encoder_input = encoder_input + pos_embeddings elif "ffn" in hparams.encoder_pos_integration: with tf.variable_scope("encoder_pos_ffn"): encoder_input = tf.concat([encoder_input, pos_embeddings], axis=2) encoder_input = transformer_ffn_layer(encoder_input, hparams, conv_padding="SAME") return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
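A small sketch of the two positional-integration modes selected by hparams.encoder_pos_integration above, under the assumption that a plain dense projection stands in for transformer_ffn_layer: either add the positional embeddings, or concatenate them and project back to hidden_size.

import tensorflow as tf

def integrate_positions(x, pos_emb, hidden_size, mode="sum"):
    if mode == "sum":
        return x + pos_emb
    # "ffn"-style integration: concatenate on the feature axis, then project back down.
    return tf.layers.dense(tf.concat([x, pos_emb], axis=-1), hidden_size, name="pos_proj")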
def transformer_prepare_encoder(inputs, target_space, hparams, features=None): """Prepare one shard of the model for the encoder. Args: inputs: a Tensor. target_space: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ ishape_static = inputs.shape.as_list() encoder_input = inputs if features and "inputs_segmentation" in features: # Packed dataset. Keep the examples from seeing each other. inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] targets_segmentation = features["targets_segmentation"] if (hasattr(hparams, "unidirectional_encoder") and hparams.unidirectional_encoder): tf.logging.info("Using unidirectional encoder") encoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle( common_layers.shape_list(inputs)[1])) else: encoder_self_attention_bias = ( common_attention.attention_bias_same_segment( inputs_segmentation, inputs_segmentation)) encoder_decoder_attention_bias = ( common_attention.attention_bias_same_segment(targets_segmentation, inputs_segmentation)) else: encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) if (hasattr(hparams, "unidirectional_encoder") and hparams.unidirectional_encoder): tf.logging.info("Using unidirectional encoder") encoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle( common_layers.shape_list(inputs)[1])) else: # Usual case - not a packed dataset. encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( common_layers.shape_list(inputs)[1]) if hparams.get("use_target_space_embedding", True): # Append target_space_id embedding to inputs. emb_target_space = common_layers.embedding( target_space, 32, ishape_static[-1], name="target_space_embedding", dtype=tf.bfloat16 if hparams.activation_dtype == "bfloat16" else tf.float32) emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space if hparams.pos == "timing": if inputs_position is not None: encoder_input = common_attention.add_timing_signal_1d_given_position( encoder_input, inputs_position) else: encoder_input = common_attention.add_timing_signal_1d(encoder_input) elif hparams.pos == "emb": encoder_input = common_attention.add_positional_embedding( encoder_input, hparams.max_length, "inputs_positional_embedding", inputs_position) if hparams.activation_dtype == "bfloat16": encoder_self_attention_bias = tf.cast(encoder_self_attention_bias, tf.bfloat16) encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias, tf.bfloat16) return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
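For reference, a self-contained sketch (plain TensorFlow, no tensor2tensor helpers) of the two bias variants this function chooses between: a lower-triangular causal bias when hparams.unidirectional_encoder is set, and a padding bias otherwise.

import tensorflow as tf

def causal_bias(length):
    # length is a Python int; 0 where key position <= query position, -1e9 elsewhere.
    band = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    return tf.reshape(-1e9 * (1.0 - band), [1, 1, length, length])

def padding_bias(padding):
    # padding: [batch, length] with 1.0 at padded positions -> bias [batch, 1, 1, length].
    return tf.expand_dims(tf.expand_dims(padding * -1e9, 1), 1)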
def encode_lex(self, encoder_input, target_space, hparams): ''' encoder_input: [batch_size, input_len, hidden_dim] return: encoder_output: [batch_size, input_len, hidden_dim] encoder_decoder_attention_bias: [batch_size, input_len] ''' encoder_output_slices = [] for i in range(encoder_input.get_shape()[2].value): encoder_input_slice = encoder_input[:, :, i, :] # bias encoder_padding = common_attention.embedding_to_padding( encoder_input_slice) print(encoder_padding.shape.as_list() ) # ==> [None, None] (None, None, 4) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding print(ignore_padding.shape.as_list() ) # ==> [None, 1, 1, None] (None, 1, 1, None, 4) # add target space to encoder input? ishape_static = encoder_input_slice.shape.as_list() print(ishape_static) # ==> [None, None, 300] (None, None, 4, 300) emb_target_space = common_layers.embedding( target_space, 32, ishape_static[-1], name="target_space_embedding") print(emb_target_space.shape.as_list()) # ==> [300] emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) print(emb_target_space.shape.as_list()) # ==> [1, 1, 300] encoder_input_slice += emb_target_space print(encoder_input_slice.shape.as_list() ) # ==> [None, None, 300] (None, None, 4, 300) # add timing signals to encoder input if hparams.pos == "timing": encoder_input_slice = common_attention.add_timing_signal_1d( encoder_input_slice) # dropout encoder_input_slice = tf.nn.dropout( encoder_input_slice, 1.0 - hparams.layer_prepostprocess_dropout) # encoder ''' multihead_attention( query_antecedent: [batch, length_q, channels], -- x, x memory_antecedent: [batch, length_m, channels], -- None, encoder_output bias: bias tensor, -- encoder_self_attention_bias total_key_depth: int, -- hparams.attention_key_channels or hparams.hidden_size total_value_depth: int, -- hparams.attention_value_channels or hparams.hidden_size output_depth: integer, -- hparams.hidden_size num_heads: integer dividing total_key_depth and total_value_depth, -- hparams.num_heads (8) dropout_rate: float, -- hparams.attention_dropout ... cache=None: dict, containing tensors which are the results of previous attentions used for fast decoding, {'k': [batch_size, 0, key_channels], 'v': [batch_size, 0, value_channels], used in decoder self-attention) ''' x = encoder_input_slice with tf.variable_scope("encoder" + str(i)): # remove pad pad_remover = None if hparams.use_pad_remover: pad_remover = expert_utils.PadRemover( common_attention.attention_bias_to_padding( encoder_self_attention_bias)) # self-attention along the sentence dimension for layer in xrange(hparams.num_encoder_layers or hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): query_antecedent = common_layers.layer_preprocess( x, hparams) y = common_attention.multihead_attention( query_antecedent=query_antecedent, memory_antecedent=None, bias=encoder_self_attention_bias, total_key_depth=hparams.attention_key_channels or hparams.hidden_size, total_value_depth=hparams. attention_value_channels or hparams.hidden_size, output_depth=hparams.hidden_size, num_heads=hparams.num_heads, dropout_rate=hparams.attention_dropout, attention_type=hparams.self_attention_type, max_relative_position=hparams. 
max_relative_position) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer.transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, pad_remover) x = common_layers.layer_postprocess(x, y, hparams) encoder_output_slice = common_layers.layer_preprocess( x, hparams) print(encoder_output_slice.shape.as_list() ) # ==> [None, None, 300] (None, None, 4, 300) encoder_output_slices.append(encoder_output_slice) encoder_output = tf.stack(encoder_output_slices, 2) print(encoder_output.shape.as_list()) # ==> [None, None, 4, 300] # -------- encoder_output_slices = [] #hparams2 = copy.deepcopy(hparams) #hparams2.hidden_size = hparams.lex_cap num_heads = int(hparams.lex_cap / 2) hparams2 = tf.contrib.training.HParams( layer_preprocess_sequence=hparams.layer_preprocess_sequence, layer_postprocess_sequence=hparams.layer_postprocess_sequence, layer_prepostprocess_dropout=hparams.layer_prepostprocess_dropout, norm_type=hparams.norm_type, hidden_size=hparams.lex_cap, norm_epsilon=hparams.norm_epsilon, ffn_layer=hparams.ffn_layer, filter_size=hparams.filter_size, relu_dropout=hparams.relu_dropout, num_heads=num_heads, attention_dropout=hparams.attention_dropout, parameter_attention_key_channels=hparams. parameter_attention_key_channels, parameter_attention_value_channels=hparams. parameter_attention_value_channels) for i in range(encoder_output.get_shape()[3].value): encoder_input_slice = encoder_output[:, :, :, i] #print(encoder_input_slice.shape.as_list()) # ==> [None, None, 4] encoder_padding = common_attention.embedding_to_padding( encoder_input_slice) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding #print(encoder_self_attention_bias.shape.as_list()) # ==> [None, 1, 1, None] # encoder ''' multihead_attention( query_antecedent: [batch, length_q, channels], -- x, x memory_antecedent: [batch, length_m, channels], -- None, encoder_output bias: bias tensor, -- encoder_self_attention_bias total_key_depth: int, -- hparams.attention_key_channels or hparams.hidden_size total_value_depth: int, -- hparams.attention_value_channels or hparams.hidden_size output_depth: integer, -- hparams.hidden_size num_heads: integer dividing total_key_depth and total_value_depth, -- hparams.num_heads (8) dropout_rate: float, -- hparams.attention_dropout ... 
cache=None: dict, containing tensors which are the results of previous attentions used for fast decoding, {'k': [batch_size, 0, key_channels], 'v': [batch_size, 0, value_channels], used in decoder self-attention) ''' x = encoder_input_slice with tf.variable_scope("encoder_extra" + str(i)): # remove pad pad_remover = None if hparams.use_pad_remover: pad_remover = expert_utils.PadRemover( common_attention.attention_bias_to_padding( encoder_self_attention_bias)) # self-attention along the lexicon dimension with tf.variable_scope("layer_extra"): with tf.variable_scope("self_attention"): #query_antecedent = layer_preprocess2(x, hparams, hparams.lex_cap) query_antecedent = common_layers.layer_preprocess( x, hparams2) y = common_attention.multihead_attention( query_antecedent=query_antecedent, memory_antecedent=None, bias=encoder_self_attention_bias, total_key_depth=hparams.attention_key_channels or hparams.lex_cap, total_value_depth=hparams.attention_value_channels or hparams.lex_cap, output_depth=hparams.lex_cap, num_heads=num_heads, dropout_rate=hparams.attention_dropout, attention_type=hparams.self_attention_type, max_relative_position=hparams.max_relative_position ) #x = layer_postprocess2(x, y, hparams, hparams.lex_cap) x = common_layers.layer_postprocess(x, y, hparams2) with tf.variable_scope("ffn"): y = transformer.transformer_ffn_layer( common_layers.layer_preprocess(x, hparams2), hparams2, pad_remover) #x = layer_postprocess2(x, y, hparams, hparams.lex_cap) x = common_layers.layer_postprocess(x, y, hparams2) #encoder_output_slice = layer_preprocess2(x, hparams, hparams.lex_cap) encoder_output_slice = common_layers.layer_preprocess( x, hparams2) #print(encoder_output_slice.shape.as_list()) # ==> [None, None, 4] (None, None, 4, 300) encoder_output_slices.append(encoder_output_slice) encoder_output = tf.stack(encoder_output_slices, 3) print(encoder_output.shape.as_list()) # ==> [None, None, 4, 300] # -------- lex_cap = encoder_output.get_shape()[2].value embed_len = encoder_output.get_shape()[3].value assert (lex_cap == hparams.lex_cap) aggregate_layer = tf.get_variable( name="Aggregate", shape=[embed_len, embed_len, lex_cap], initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1)) encoder_output = tf.tensordot(encoder_output, aggregate_layer, axes=[[2, 3], [1, 2]]) print(encoder_output.shape.as_list()) # ==> [None, None, 300] return encoder_output, encoder_decoder_attention_bias
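A minimal sketch of the final aggregation in encode_lex, assuming the intended contraction pairs the lex_cap axes together and the embedding axes together (axes chosen so the pairwise shapes match and the result is [batch, length, embed_len], as the printed comment indicates).

import tensorflow as tf

def aggregate_lexicon(enc, aggregate):
    # enc: [batch, length, lex_cap, emb]; aggregate: [emb, emb, lex_cap] (assumed shapes).
    # Contract lex_cap with lex_cap and emb with emb, leaving [batch, length, emb].
    return tf.tensordot(enc, aggregate, axes=[[2, 3], [2, 1]])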
def body(self, features): """Transformer main model_fn. Args: features: Map of features to the model. Should contain the following: "inputs": Transformer inputs [batch_size, input_length, hidden_dim] "targets": Target decoder outputs. [batch_size, decoder_length, hidden_dim] "target_space_id": A scalar int from data_generators.problem.SpaceID. Returns: Final decoder representation. [batch_size, decoder_length, hidden_dim] """ tf.logging.info("Using PgScratch BODY function.") hparams = self._hparams losses = {} inputs = features["inputs"] target_space = features["target_space_id"] # encoder_output: <tf.float32>[batch_size, input_length, hidden_dim] # encoder_decoder_attention_bias: <tf.float32>[batch_size, input_length] encoder_output, encoder_decoder_attention_bias = self.encode( inputs, target_space, hparams, features=features, losses=losses) with tf.variable_scope("knowledge"): with tf.name_scope("knowledge_encoding"): # Encode knowledge. # <tf.float32>[batch_size, triple_num, emb_dim] fact_embedding, fact_lengths = self.encode_knowledge_bottom( features) tf.logging.info("Encoded knowledge") with tf.name_scope("knowledge_selection_and_loss"): # Compute knowledge selection and loss. triple_logits, avg_triple_selection_loss, knowledge_encoder_output, transe_loss = self.compute_knowledge_selection_and_loss( features, encoder_output, fact_embedding, fact_lengths, hparams.margin, hparams.num_negative_samples) losses["kb_loss"] = avg_triple_selection_loss losses["transe_loss"] = transe_loss if hparams.attend_kb: tf.logging.info("ATTEND_KB is ACTIVE") with tf.name_scope("knowledge_attention"): knowledge_padding = tf.zeros_like(triple_logits, dtype=tf.float32) knowledge_attention_bias = common_attention.attention_bias_ignore_padding( knowledge_padding) encoder_output = tf.concat( [knowledge_encoder_output, encoder_output], 1) encoder_decoder_attention_bias = tf.concat( [knowledge_attention_bias, encoder_decoder_attention_bias], -1) else: tf.logging.info("ATTEND_KB is INACTIVE") targets = features["targets"] targets_shape = common_layers.shape_list(targets) targets = common_layers.flatten4d3d(targets) (decoder_input, decoder_self_attention_bias ) = transformer.transformer_prepare_decoder(targets, hparams, features=features) decode_kwargs = {} decoder_output = self.decode( decoder_input, encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, nonpadding=transformer.features_to_nonpadding(features, "targets"), losses=losses, **decode_kwargs) expected_attentions = features.get("expected_attentions") if expected_attentions is not None: attention_loss = common_attention.encoder_decoder_attention_loss( expected_attentions, self.attention_weights, hparams.expected_attention_loss_type, hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} ret = tf.reshape(decoder_output, targets_shape) if losses: return ret, losses else: return ret
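A hedged sketch of the attend_kb branch above: the knowledge encoder output is prepended to the encoder memory, and an all-zero bias slice (knowledge positions are never masked) is prepended to the encoder-decoder attention bias. Shapes and names below are assumptions for illustration.

import tensorflow as tf

def attach_knowledge(encoder_output, enc_dec_bias, knowledge_output, triple_logits):
    # triple_logits: [batch, triple_num]; knowledge positions get a zero bias (unmasked).
    knowledge_bias = tf.expand_dims(tf.expand_dims(tf.zeros_like(triple_logits), 1), 1)
    memory = tf.concat([knowledge_output, encoder_output], axis=1)   # along memory length
    bias = tf.concat([knowledge_bias, enc_dec_bias], axis=-1)        # along key length
    return memory, bias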
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): """Fast decoding. Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. Args: features: a map of string to model features. decode_length: an integer. How many additional timesteps to decode. beam_size: number of beams. top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. larger the alpha, stronger the preference for longer translations. Returns: samples: an integer `Tensor`. Top samples from the beam search Raises: NotImplementedError: If there are multiple data shards. """ #JI: set images shapes imageP = features["imageP"] imageP.set_shape([None, 19600]) imageP = tf.reshape(imageP, [-1, img_dim, 100]) if self._num_datashards != 1: raise NotImplementedError("Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams inputs = features["inputs"] batch_size = tf.shape(inputs)[0] target_modality = self._problem_hparams.target_modality if t2t_model.is_class_modality(target_modality): decode_length = 1 else: decode_length = tf.shape(inputs)[1] + decode_length # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) if len(inputs.shape) < 5: inputs = tf.expand_dims(inputs, axis=4) s = tf.shape(inputs) inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.input_modality["inputs"] with tf.variable_scope(input_modality.name): inputs = input_modality.bottom_sharded(inputs, dp) #JI: send images to encoder if needed with tf.variable_scope("body"): encoder_output, encoder_decoder_attention_bias = dp( self.encode, inputs, features["target_space_id"], hparams, imageP=None) encoder_output = encoder_output[0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0] if hparams.pos == "timing": timing_signal = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: inputs ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom_sharded(targets, dp)[0] targets = common_layers.flatten4d3d(targets) # TODO(llion): Explain! Is this even needed?
targets = tf.cond( tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if hparams.pos == "timing": targets += timing_signal[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): #JI: send images to decoder if needed body_outputs = dp(self.decode, targets, cache["encoder_output"], cache["encoder_decoder_attention_bias"], bias, hparams, cache, imageP=cache["imageP"], imageP_decoder_self_attention_bias=cache["imageP_decoder_self_attention_bias"]) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] return tf.squeeze(logits, axis=[1, 2, 3]), cache key_channels = hparams.attention_key_channels or hparams.hidden_size value_channels = hparams.attention_value_channels or hparams.hidden_size num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers cache = { "layer_%d" % layer: { "k": tf.zeros([batch_size, 0, key_channels]), "v": tf.zeros([batch_size, 0, value_channels]), } for layer in range(num_layers) } # Set 2nd dim to None since it's not invariant in the tf.while_loop # Note: Tensor.set_shape() does not work here since it merges shape info. # TODO(llion); Find a more robust solution. # pylint: disable=protected-access for layer in cache: cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels]) cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels]) # pylint: enable=protected-access # get image attention bias for decoder img_encoder_padding = common_attention.embedding_to_padding(imageP) imageP_decoder_self_attention_bias = common_attention.attention_bias_ignore_padding(img_encoder_padding) cache["encoder_output"] = encoder_output cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias #get images to cache for input to decoder cache["imageP"] = imageP cache["imageP_decoder_self_attention_bias"] = imageP_decoder_self_attention_bias if beam_size > 1: # Beam Search target_modality = ( self._hparams.problems[self._problem_idx].target_modality) vocab_size = target_modality.top_dimensionality initial_ids = tf.zeros([batch_size], dtype=tf.int32) decoded_ids, scores = beam_search.beam_search( symbols_to_logits_fn, initial_ids, beam_size, decode_length, vocab_size, alpha, states=cache, stop_early=(top_beams == 1)) if top_beams == 1: decoded_ids = decoded_ids[:, 0, 1:] else: decoded_ids = decoded_ids[:, :top_beams, 1:] else: # Greedy def inner_loop(i, next_id, decoded_ids, cache): logits, cache = symbols_to_logits_fn(next_id, i, cache) temperature = (0.0 if hparams.sampling_method == "argmax" else hparams.sampling_temp) next_id = tf.expand_dims( common_layers.sample_with_temperature(logits, temperature), axis=1) decoded_ids = tf.concat([decoded_ids, next_id], axis=1) return i + 1, next_id, decoded_ids, cache decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) scores = None next_id = tf.zeros([batch_size, 1], dtype=tf.int64) _, _, decoded_ids, _ = tf.while_loop( # TODO(llion): Early stopping. 
lambda i, *_: tf.less(i, decode_length), inner_loop, [tf.constant(0), next_id, decoded_ids, cache], shape_invariants=[ tf.TensorShape([]), tf.TensorShape([None, None]), tf.TensorShape([None, None]), nest.map_structure(lambda t: tf.TensorShape(t.shape), cache), ]) return decoded_ids, scores
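A short sketch of why the decoding cache above is created with zero-length key/value tensors and why their time dimension is forced to None: each step appends exactly one new key/value pair, so the tensors grow inside the tf.while_loop and need a shape invariant that leaves that dimension unspecified. The helper name is hypothetical.

import tensorflow as tf

def append_step(cache_k, new_k):
    # cache_k: [batch, steps_so_far, key_channels]; new_k: [batch, 1, key_channels].
    return tf.concat([cache_k, new_k], axis=1)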
def test_aaa_glow_training(self, depths, split_plans, prior_type): with tf.Graph().as_default(): _, x_mask, _ = self.get_data() x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS), mean=10.0, stddev=3.0, dtype=DTYPE) bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask) hparams = self.get_hparams() hparams.prior_type = prior_type hparams.depths = depths hparams.split_plans = split_plans n_levels = len(hparams.depths.split("/")) kwargs = self.get_kwargs(x_mask, hparams) _ = kwargs.pop("decoder_self_attention_bias") x_inv, _, _, _ = glow.glow( "glow", x, x_mask, bias, inverse=False, init=True, disable_dropout=True, **kwargs) curr_dir = tempfile.mkdtemp() model_path = os.path.join(curr_dir, "model") with tf.Session() as session: saver = tf.train.Saver() session.run(tf.global_variables_initializer()) session.run(x_inv) saver.save(session, model_path) with tf.Graph().as_default(): _, x_mask, _ = self.get_data() x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS), mean=10.0, stddev=3.0, dtype=DTYPE) bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask) hparams = self.get_hparams() hparams.depths = depths hparams.split_plans = split_plans kwargs = self.get_kwargs(x_mask, hparams) _ = kwargs.pop("decoder_self_attention_bias") log_q_z = gops.standard_normal_density(x, x_mask) log_q_z = tf.reduce_sum(log_q_z) / tf.reduce_sum(x_mask) x_inv, logabsdets, log_ps, zs = glow.glow( "glow", x, x_mask, bias, inverse=False, init=False, disable_dropout=True, **kwargs) x_inv_inv, logabsdets_inv, log_ps_inv, _ = glow.glow( "glow", x_inv, x_mask, bias, inverse=True, split_zs=zs, init=False, disable_dropout=True, **kwargs) logabsdets = tf.reduce_sum( logabsdets, axis=0) / tf.reduce_sum(x_mask) logabsdets_inv = tf.reduce_sum( logabsdets_inv, axis=0) / tf.reduce_sum(x_mask) log_ps = tf.reduce_sum(log_ps, axis=0) / tf.reduce_sum(x_mask) log_ps_inv = tf.reduce_sum(log_ps_inv, axis=0) / tf.reduce_sum(x_mask) with tf.Session() as session: saver = tf.train.Saver() saver.restore(session, model_path) (x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps, logabsdets_inv, log_ps_inv) = session.run([ x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps, logabsdets_inv, log_ps_inv]) diff = x - x_inv_inv log_ps_diff = log_ps - log_ps_inv logabsdets_sum = logabsdets + logabsdets_inv self.assertEqual( x_inv.shape, (BATCH_SIZE, TARGET_LENGTH//(2**(n_levels-1)), N_CHANNELS)) print (np.max(np.abs(diff))) print (np.max(np.abs(log_ps_diff))) print (np.max(np.abs(logabsdets_sum))) self.assertTrue(np.allclose(diff, 0.0, atol=1e-4), msg=np.max(np.abs(diff))) self.assertTrue(np.allclose(log_ps_diff, 0.0, atol=1e-4), msg=np.max(np.abs(log_ps_diff))) self.assertTrue(np.allclose(logabsdets_sum, 0.0, atol=1e-4), msg=np.max(np.abs(logabsdets_sum)))
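In essence, the test above checks the round-trip property of the flow; a minimal numpy sketch of that check (hypothetical helper, same tolerances) is:

import numpy as np

def check_flow_roundtrip(x, x_roundtrip, logdet_fwd, logdet_inv, atol=1e-4):
    # Forward followed by inverse must reconstruct the input ...
    assert np.allclose(x, x_roundtrip, atol=atol)
    # ... and the forward/inverse log-determinants must cancel.
    assert np.allclose(logdet_fwd + logdet_inv, 0.0, atol=atol)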