def prepare_permutation(batch,
                        dataset,
                        tgt_vocab_size):
    """Transform a batch dictionary into a dataclass standard format
    for the transformer to process

    Arguments:

    batch: dict of tf.Tensors
        a dictionary that contains tensors from a tfrecord dataset;
        this function assumes region-features are used
    dataset: str
        type of dataset (captioning or wmt)
    tgt_vocab_size: tf.Tensor
        the number of words in the target vocabulary of the model; used
        in order to calculate labels for the language model logits

    Returns:

    inputs: TransformerInput
        the input to be passed into a transformer model with attributes
        necessary for also computing the loss function"""

    # process the dataset batch dictionary into the standard
    # model input format
    if dataset == 'captioning':
        prepare_batch = prepare_batch_captioning
    elif dataset in ['wmt', 'django', 'gigaword']:
        prepare_batch = prepare_batch_wmt
    inputs = prepare_batch(batch)

    # the order is fixed
    if dataset == 'captioning':
        bt, bw = batch['token_indicators'], batch['words']
    elif dataset in ['wmt', 'django', 'gigaword']:
        bt, bw = batch['decoder_token_indicators'], batch['decoder_words']
    inputs[5] = get_permutation(bt, bw, tf.constant('l2r'))

    # convert the permutation to absolute and relative positions
    inputs[6] = inputs[5][:, :-1, :-1]
    inputs[7] = permutation_to_relative(inputs[5])

    # convert the permutation to label distributions
    # also records the partial absolute position at each decoding time step
    hard_pointer_labels, inputs[10] = permutation_to_pointer(
        inputs[5][:, tf.newaxis, :, :])
    inputs[8] = tf.squeeze(hard_pointer_labels, axis=1)
    inputs[9] = tf.matmul(
        inputs[5][:, 1:, 1:],
        tf.one_hot(inputs[4], tf.cast(tgt_vocab_size, tf.int32)))

    return inputs
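
# --- usage sketch, not part of the original module ---
# A hedged example of how the outputs of prepare_permutation above are typically
# consumed: inputs[4] holds the shifted target token ids, inputs[2] the decoder
# token mask (it is set from inputs[0] in the visualization helpers below), and
# tgt_vocab_size sizes the label distribution. `logits` stands in for whatever the
# decoder produces ([batch, length - 1, vocab]); the decoder call itself is omitted
# because its exact output layout is model-specific.
def example_token_loss(batch, logits, tgt_vocab_size):
    inputs = prepare_permutation(batch, 'captioning', tgt_vocab_size)
    labels = tf.one_hot(inputs[4], tf.cast(tgt_vocab_size, tf.int32))
    token_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    mask = tf.cast(inputs[2], tf.float32)
    return tf.reduce_sum(token_loss * mask) / tf.reduce_sum(mask)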
def visualize_function(b):
    # calculate the ground truth sequence for this batch; and
    # perform beam search using the current model
    # show several model predicted sequences and their likelihoods
    inputs = prepare_batch_for_lm(tf.constant(1), b)
    cap, logp, rel_pos = beam_search(
        inputs, model, dataset_type,
        beam_size=beam_size, max_iterations=50, return_rel_pos=True)
    pos = tf.argmax(rel_pos, axis=-1, output_type=tf.int32) - 1
    pos = tf.reduce_sum(tf.nn.relu(pos), axis=2)
    pos = tf.one_hot(pos, tf.shape(pos)[2], dtype=tf.float32)

    # select the most likely beam
    cap = cap[:, 0]
    pos = pos[:, 0]

    # the original cap is missing a start token
    inputs = prepare_batch_for_lm(tf.constant(1), b)
    full_cap = tf.concat([tf.fill([tf.shape(cap)[0], 1], 2), cap], 1)
    inputs[0] = full_cap[:, :-1]
    inputs[2] = tf.logical_not(tf.equal(inputs[0], 0))
    inputs[4] = full_cap[:, 1:]

    # todo: make sure this is not transposed
    inputs[5] = pos

    # convert the permutation to absolute and relative positions
    inputs[6] = inputs[5][:, :-1, :-1]
    inputs[7] = permutation_to_relative(inputs[5])

    # convert the permutation to label distributions
    # also records the partial absolute position at each decoding time step
    hard_pointer_labels, inputs[10] = permutation_to_pointer(
        inputs[5][:, tf.newaxis, :, :])
    inputs[8] = tf.squeeze(hard_pointer_labels, axis=1)
    inputs[9] = tf.matmul(
        inputs[5][:, 1:, 1:],
        tf.one_hot(inputs[4], tf.cast(vocabs[-1].size(), tf.int32)))

    # visuals is a list of attention mechanism scores
    _, visuals = model.visualize(inputs)

    # return the last attention layer's attention head 0
    return full_cap, tf.cast(pos, tf.int32), visuals[-1][:, 0, :, :], inputs[-1]
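
# --- usage sketch, not part of the original module ---
# A hedged example of consuming visualize_function above: run it on one batch and
# dump the attention map of head 0 in the last layer for the first example.
# `loader` is a hypothetical tf.data.Dataset of batch dictionaries, eager execution
# is assumed, and matplotlib is only needed for this illustration.
def dump_attention_example(loader, out_path='attention.png'):
    import matplotlib.pyplot as plt
    for b in loader.take(1):
        # the last return value is inputs[-1] (the Faster-RCNN boxes for captioning)
        full_cap, pos, attention, boxes = visualize_function(b)
        plt.imshow(attention[0].numpy())  # [queries, keys] for example 0
        plt.xlabel('keys')
        plt.ylabel('queries')
        plt.savefig(out_path)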
def visualize_function(b):
    # calculate the ground truth sequence for this batch; and
    # perform beam search using the current model
    # show several model predicted sequences and their likelihoods
    inputs = prepare_batch_for_lm(tf.constant(1), b)
    cap, logp, rel_pos = beam_search(
        inputs, model, dataset_type,
        beam_size=beam_size, max_iterations=50, return_rel_pos=True)
    pos = tf.argmax(rel_pos, axis=-1, output_type=tf.int32) - 1
    original_pos = tf.reduce_sum(tf.nn.relu(pos), axis=2)
    pos = tf.one_hot(original_pos, tf.shape(original_pos)[2], dtype=tf.float32)

    # select the most likely beam
    cap = cap[:, 0]
    pos = pos[:, 0]

    # the original cap is missing a start token
    inputs = prepare_batch_for_lm(tf.constant(1), b)
    full_cap = tf.concat([tf.fill([tf.shape(cap)[0], 1], 2), cap], 1)
    inputs[0] = full_cap[:, :-1]
    inputs[2] = tf.logical_not(tf.equal(inputs[0], 0))
    inputs[4] = full_cap[:, 1:]

    # todo: make sure this is not transposed
    inputs[5] = pos

    # convert the permutation to absolute and relative positions
    inputs[6] = inputs[5][:, :-1, :-1]
    inputs[7] = permutation_to_relative(inputs[5])

    # convert the permutation to label distributions
    # also records the partial absolute position at each decoding time step
    hard_pointer_labels, inputs[10] = \
        permutation_to_pointer(inputs[5][:, tf.newaxis, :, :])
    inputs[8] = tf.squeeze(hard_pointer_labels, axis=1)
    inputs[9] = tf.matmul(
        inputs[5][:, 1:, 1:],
        tf.one_hot(inputs[4], tf.cast(vocabs[-1].size(), tf.int32)))

    # perturb the image features by removing some of them
    original_mask = inputs[3]
    range_i = tf.range(tf.shape(original_mask)[1])[tf.newaxis]
    original_pos = original_pos[:, 0]
    out_pos = original_pos[:, tf.newaxis]

    with tf.control_dependencies(inputs):
        for loc_i in tf.range(tf.shape(original_mask)[1]):

            # set the shape of out_pos
            tf.autograph.experimental.set_loop_options(
                shape_invariants=[(
                    out_pos, tf.TensorShape([None, None, None]))])

            # create a mask that eliminates exactly one feature
            inputs[3] = tf.logical_and(
                original_mask, tf.not_equal(range_i, loc_i))

            # use searched adaptive order to
            # check how the predicted order changes
            perturb_rel_pos = adaptive_search(
                inputs, model, dataset_type,
                beam_size=beam_size, max_iterations=50,
                return_rel_pos=True)[2]

            # convert the rel_pos matrix to an array of positions
            perturb_pos = tf.argmax(
                perturb_rel_pos, axis=-1, output_type=tf.int32) - 1
            perturb_pos = tf.reduce_sum(
                tf.nn.relu(perturb_pos), axis=2)[:, 0]

            # if extra pad tokens are present at the end they will be removed
            # by adaptive search, and this places them back
            additional_pad_i = tf.range(tf.shape(original_pos)[1])[tf.newaxis]
            additional_pad_i = tf.broadcast_to(
                additional_pad_i, tf.shape(original_pos))
            additional_pad_i = additional_pad_i[:, tf.shape(perturb_pos)[1]:]
            perturb_pos = tf.concat([perturb_pos, additional_pad_i], 1)

            # add to the set of positions
            with tf.control_dependencies([out_pos, perturb_pos]):
                out_pos = tf.concat([
                    out_pos, perturb_pos[:, tf.newaxis]], 1)

    # return the perturbed sequences and the Faster-RCNN boxes
    return full_cap, tf.cast(pos, tf.int32), out_pos, inputs[-1]
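
# --- usage sketch, not part of the original module ---
# A hedged example of reading the output of the perturbation variant above: row 0 of
# out_pos is the unperturbed generation order, and row i + 1 is the order predicted
# after image region i was masked out, which makes it easy to measure how sensitive
# the ordering is to each region. `loader` is again a hypothetical tf.data.Dataset.
def order_sensitivity_example(loader):
    for b in loader.take(1):
        full_cap, pos, out_pos, boxes = visualize_function(b)
        baseline = out_pos[:, 0]
        # count how many positions change when each region is removed
        changes = tf.reduce_sum(tf.cast(
            tf.not_equal(out_pos[:, 1:], baseline[:, tf.newaxis]),
            tf.int32), axis=-1)
        return changes, boxes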
def prepare_permutation(batch,
                        tgt_vocab_size,
                        order,
                        dataset,
                        policy_gradient,
                        decoder=None):
    """Transform a batch dictionary into a dataclass standard format
    for the transformer to process

    Arguments:

    batch: dict of tf.Tensors
        a dictionary that contains tensors from a tfrecord dataset;
        this function assumes region-features are used
    tgt_vocab_size: tf.Tensor
        the number of words in the target vocabulary of the model; used
        in order to calculate labels for the language model logits
    order: str or callable
        the autoregressive ordering to train Transformer-InDIGO using;
        l2r or r2l for now, will support soft orders later
    dataset: str
        type of dataset (captioning or wmt)
    policy_gradient: str
        whether to use policy gradient for training; choices:
        none: no policy gradient
        with_bvn: use policy gradient with probabilities of hard
            permutations based on the Birkhoff-von Neumann decomposition
            of the soft permutation
        without_bvn: after applying the Hungarian algorithm to the soft
            permutation to obtain hard permutations, the probabilities of
            hard permutations are proportional to the Gumbel-Matching
            distribution, i.e. exp(<X,P>_F),
            see https://arxiv.org/abs/1802.08665

    Returns:

    inputs: TransformerInput
        the input to be passed into a transformer model with attributes
        necessary for also computing the loss function"""

    # process the dataset batch dictionary into the standard
    # model input format
    if dataset == 'captioning':
        words = batch['words']
        mask = batch['token_indicators']
        prepare_batch_for_lm = prepare_batch_for_lm_captioning
        prepare_batch_for_pt = prepare_batch_for_pt_captioning
    elif dataset in ['wmt', 'django', 'gigaword']:
        words = batch['decoder_words']
        mask = batch['decoder_token_indicators']
        prepare_batch_for_lm = prepare_batch_for_lm_wmt
        prepare_batch_for_pt = prepare_batch_for_pt_wmt

    inputs = prepare_batch_for_lm(tf.constant(1), batch)
    permu_inputs = None

    # the order is fixed
    if order in ['r2l', 'l2r', 'rare', 'common', 'test']:
        inputs[5] = get_permutation(mask, words, tf.constant(order))

    # pass the training example through the permutation transformer
    # to obtain a doubly stochastic matrix
    if isinstance(order, tf.keras.Model):  # corresponds to soft orderings
        if policy_gradient != 'without_bvn':
            inputs[5] = order(prepare_batch_for_pt(
                tf.constant(True), tf.constant(1), batch), training=True)
        else:
            permu_inputs = prepare_batch_for_pt(
                tf.constant(True), tf.constant(1), batch)
            inputs[5], activations, kl, log_nom, log_denom = \
                order(permu_inputs, training=True)
            permu_inputs[-6] = activations
            permu_inputs[-5] = kl
            permu_inputs[-4] = log_nom - log_denom

    # pass the training example through the decoder using searched
    # adaptive order (sao) to obtain a permutation
    if order == 'sao' and decoder is not None:
        cap, logp, rel_pos = adaptive_search(
            inputs, decoder, dataset,
            beam_size=8, max_iterations=200, return_rel_pos=True)
        pos = tf.argmax(rel_pos, axis=-1, output_type=tf.int32) - 1
        pos = tf.reduce_sum(tf.nn.relu(pos), axis=2)
        pos = tf.one_hot(pos, tf.shape(pos)[2], dtype=tf.float32)
        ind = tf.random.uniform(
            [tf.shape(pos)[0], 1], maxval=7, dtype=tf.int32)
        # todo: make sure this is not transposed
        inputs[5] = tf.squeeze(tf.gather(pos, ind, batch_dims=1), 1)

    if policy_gradient == 'with_bvn':
        raise NotImplementedError
    elif policy_gradient == 'without_bvn':
        inputs[5] = tf.stop_gradient(inputs[5])

    # convert the permutation to absolute and relative positions
    inputs[6] = inputs[5][:, :-1, :-1]
    inputs[7] = permutation_to_relative(inputs[5])

    # convert the permutation to label distributions
    # also records the partial absolute position at each decoding time step
    hard_pointer_labels, inputs[10] = permutation_to_pointer(
        inputs[5][:, tf.newaxis, :, :])
    inputs[8] = tf.squeeze(hard_pointer_labels, axis=1)
    inputs[9] = tf.matmul(
        inputs[5][:, 1:, 1:],
        tf.one_hot(inputs[4], tf.cast(tgt_vocab_size, tf.int32)))

    return inputs, permu_inputs
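
# --- usage sketch, not part of the original module ---
# A hedged reminder of the argument order of this variant (batch, vocab size, order,
# dataset, policy gradient, decoder), which differs from the simpler
# prepare_permutation(batch, dataset, tgt_vocab_size) shown earlier in this section.
# `pt_model` (a permutation transformer) and `decoder_model` are hypothetical names.
def example_prepare_calls(batch, pt_model, decoder_model, tgt_vocab_size):
    # fixed left-to-right ordering, no policy gradient
    inputs, _ = prepare_permutation(
        batch, tgt_vocab_size, 'l2r', 'captioning', 'none')
    # soft ordering from a permutation transformer with the Gumbel-Matching
    # based policy gradient; permu_inputs then carries the activations, KL term,
    # and log-probability ratio consumed by a downstream policy-gradient loss
    inputs, permu_inputs = prepare_permutation(
        batch, tgt_vocab_size, pt_model, 'captioning', 'without_bvn',
        decoder=decoder_model)
    return inputs, permu_inputs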