Example #1
def prepare_permutation(batch, dataset, tgt_vocab_size):
    """Transform a batch dictionary into a dataclass standard format
    for the transformer to process

    Arguments:

    batch: dict of tf.Tensors
        a dictionary that contains tensors from a tfrecord dataset;
        this function assumes region-features are used
    dataset: str
        type of dataset (captioning, wmt, django, or gigaword)
    tgt_vocab_size: tf.Tensor
        the number of words in the target vocabulary of the model; used to
        calculate labels for the language model logits

    Returns:

    inputs: TransformerInput
        the input to be passed into a transformer model, with the attributes
        needed to also compute the loss function"""

    # process the dataset batch dictionary into the standard
    # model input format
    if dataset == 'captioning':
        prepare_batch = prepare_batch_captioning
    elif dataset in ['wmt', 'django', 'gigaword']:
        prepare_batch = prepare_batch_wmt
    inputs = prepare_batch(batch)

    # the generation order is fixed to left-to-right
    if dataset == 'captioning':
        bt, bw = batch['token_indicators'], batch['words']
    elif dataset in ['wmt', 'django', 'gigaword']:
        bt, bw = batch['decoder_token_indicators'], batch['decoder_words']
    inputs[5] = get_permutation(bt, bw, tf.constant('l2r'))

    # convert the permutation to absolute and relative positions
    inputs[6] = inputs[5][:, :-1, :-1]
    inputs[7] = permutation_to_relative(inputs[5])

    # convert the permutation to label distributions
    # also records the partial absolute position at each decoding time step
    hard_pointer_labels, inputs[10] = permutation_to_pointer(
        inputs[5][:, tf.newaxis, :, :])
    inputs[8] = tf.squeeze(hard_pointer_labels, axis=1)
    inputs[9] = tf.matmul(
        inputs[5][:, 1:, 1:],
        tf.one_hot(inputs[4], tf.cast(tgt_vocab_size, tf.int32)))

    return inputs
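
# The snippet below is an illustrative sketch (not part of the original
# example) of the label-distribution step above: multiplying the hard
# permutation matrix by the one-hot targets reorders the labels so that row t
# holds the word to predict at decoding step t. The permutation, target ids,
# and vocabulary size here are made-up toy values.
import tensorflow as tf

toy_perm = tf.constant([[[0., 1., 0.],
                         [0., 0., 1.],
                         [1., 0., 0.]]])    # [batch=1, len=3, len=3]
toy_targets = tf.constant([[5, 7, 9]])      # [batch=1, len=3] target token ids
toy_labels = tf.matmul(toy_perm, tf.one_hot(toy_targets, 12))
# toy_labels[0, 0] is the one-hot of token 7, toy_labels[0, 1] of token 9,
# and toy_labels[0, 2] of token 5, i.e. the targets in generation order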
Example #2
    def visualize_function(b):
        # calculate the ground truth sequence for this batch and perform
        # beam search using the current model to show several predicted
        # sequences and their likelihoods
        inputs = prepare_batch_for_lm(tf.constant(1), b)
        cap, logp, rel_pos = beam_search(inputs,
                                         model,
                                         dataset_type,
                                         beam_size=beam_size,
                                         max_iterations=50,
                                         return_rel_pos=True)

        pos = tf.argmax(rel_pos, axis=-1, output_type=tf.int32) - 1
        pos = tf.reduce_sum(tf.nn.relu(pos), axis=2)
        pos = tf.one_hot(pos, tf.shape(pos)[2], dtype=tf.float32)

        # select the most likely beam
        cap = cap[:, 0]
        pos = pos[:, 0]

        # the original cap is missing a start token
        inputs = prepare_batch_for_lm(tf.constant(1), b)
        full_cap = tf.concat([tf.fill([tf.shape(cap)[0], 1], 2), cap], 1)
        inputs[0] = full_cap[:, :-1]
        inputs[2] = tf.logical_not(tf.equal(inputs[0], 0))
        inputs[4] = full_cap[:, 1:]

        # todo: make sure this is not transposed
        inputs[5] = pos
        # convert the permutation to absolute and relative positions
        inputs[6] = inputs[5][:, :-1, :-1]
        inputs[7] = permutation_to_relative(inputs[5])

        # convert the permutation to label distributions
        # also records the partial absolute position at each decoding time step
        hard_pointer_labels, inputs[10] = permutation_to_pointer(
            inputs[5][:, tf.newaxis, :, :])
        inputs[8] = tf.squeeze(hard_pointer_labels, axis=1)
        inputs[9] = tf.matmul(
            inputs[5][:, 1:, 1:],
            tf.one_hot(inputs[4], tf.cast(vocabs[-1].size(), tf.int32)))

        # visuals is a list of attention mechanism scores
        _, visuals = model.visualize(inputs)

        # return the last attention layer's attention head 0
        return full_cap, tf.cast(pos,
                                 tf.int32), visuals[-1][:, 0, :, :], inputs[-1]
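
    # The snippet below is an illustrative sketch (not part of the original
    # example) of how the rel_pos tensor returned by beam search is collapsed
    # into absolute positions above, assuming its last axis one-hot encodes
    # pairwise relative order over {-1, 0, +1} and that tf is the tensorflow
    # module imported by the surrounding code; the relative matrix below is a
    # made-up toy value for a 3-token sequence
    toy_rel = tf.constant([[[[ 0,  1,  1],
                             [-1,  0, -1],
                             [-1,  1,  0]]]])     # [batch=1, beam=1, 3, 3]
    toy_rel_pos = tf.one_hot(toy_rel + 1, 3)      # [1, 1, 3, 3, 3]
    toy_pos = tf.argmax(toy_rel_pos, axis=-1, output_type=tf.int32) - 1
    toy_pos = tf.reduce_sum(tf.nn.relu(toy_pos), axis=2)
    # toy_pos == [[[0, 2, 1]]]: counting the +1 entries per column recovers
    # each token's absolute slot in the generation order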
    def visualize_function(b):
        # calculate the ground truth sequence for this batch and perform
        # beam search using the current model to show several predicted
        # sequences and their likelihoods
        inputs = prepare_batch_for_lm(tf.constant(1), b)
        cap, logp, rel_pos = beam_search(
            inputs, model, dataset_type,
            beam_size=beam_size, max_iterations=50, return_rel_pos=True)

        pos = tf.argmax(rel_pos, axis=-1, output_type=tf.int32) - 1
        original_pos = tf.reduce_sum(tf.nn.relu(pos), axis=2)
        pos = tf.one_hot(original_pos, tf.shape(original_pos)[2], dtype=tf.float32)

        # select the most likely beam
        cap = cap[:, 0]
        pos = pos[:, 0]

        # the original cap is missing a start token
        inputs = prepare_batch_for_lm(tf.constant(1), b)
        full_cap = tf.concat([tf.fill([tf.shape(cap)[0], 1], 2), cap], 1)
        inputs[0] = full_cap[:, :-1]
        inputs[2] = tf.logical_not(tf.equal(inputs[0], 0))
        inputs[4] = full_cap[:, 1:]

        # todo: make sure this is not transposed
        inputs[5] = pos
        # convert the permutation to absolute and relative positions
        inputs[6] = inputs[5][:, :-1, :-1]
        inputs[7] = permutation_to_relative(inputs[5])

        # convert the permutation to label distributions
        # also records the partial absolute position at each decoding time step
        hard_pointer_labels, inputs[10] = \
            permutation_to_pointer(inputs[5][:, tf.newaxis, :, :])
        inputs[8] = tf.squeeze(hard_pointer_labels, axis=1)
        inputs[9] = tf.matmul(inputs[5][:, 1:, 1:],
            tf.one_hot(inputs[4], tf.cast(vocabs[-1].size(), tf.int32)))

        # perturb the image features by removing some of them
        original_mask = inputs[3]
        range_i = tf.range(tf.shape(original_mask)[1])[tf.newaxis]
        original_pos = original_pos[:, 0]
        out_pos = original_pos[:, tf.newaxis]

        with tf.control_dependencies(inputs):

            for loc_i in tf.range(tf.shape(original_mask)[1]):

                # set the shape of out_pos
                tf.autograph.experimental.set_loop_options(
                    shape_invariants=[(
                        out_pos, tf.TensorShape([None, None, None]))])

                # create a mask that eliminates exactly one feature
                inputs[3] = tf.logical_and(
                    original_mask, tf.not_equal(range_i, loc_i))

                # use searched adaptive order to
                # check how the predicted order changes
                perturb_rel_pos = adaptive_search(
                    inputs,
                    model,
                    dataset_type,
                    beam_size=beam_size,
                    max_iterations=50,
                    return_rel_pos=True)[2]

                # convert the rel_pos matrix to an array of positions
                perturb_pos = tf.argmax(
                    perturb_rel_pos, axis=-1, output_type=tf.int32) - 1
                perturb_pos = tf.reduce_sum(
                    tf.nn.relu(perturb_pos), axis=2)[:, 0]

                # if extra pad tokens are present at the end they will be removed
                # by adaptive search, and this places them back
                additional_pad_i = tf.range(tf.shape(original_pos)[1])[tf.newaxis]
                additional_pad_i = tf.broadcast_to(additional_pad_i, tf.shape(original_pos))
                additional_pad_i = additional_pad_i[:, tf.shape(perturb_pos)[1]:]
                perturb_pos = tf.concat([perturb_pos, additional_pad_i], 1)

                # add to the set of positions
                with tf.control_dependencies([out_pos, perturb_pos]):
                    out_pos = tf.concat([
                        out_pos, perturb_pos[:, tf.newaxis]], 1)

            # return the perturbed sequences and the Faster-RCNN boxes
            return full_cap, tf.cast(pos, tf.int32), out_pos, inputs[-1]
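
# The snippet below is an illustrative sketch (not part of the original
# example) of the feature-perturbation step in the loop above: combining the
# original padding mask with tf.not_equal removes exactly one image feature at
# a time before re-running the search. The mask and the dropped index are
# made-up toy values.
import tensorflow as tf

toy_mask = tf.constant([[True, True, True, True]])       # [batch=1, num_features=4]
toy_range = tf.range(tf.shape(toy_mask)[1])[tf.newaxis]  # [[0, 1, 2, 3]]
toy_loc = tf.constant(2)                                 # index of the feature to drop
toy_perturbed = tf.logical_and(toy_mask, tf.not_equal(toy_range, toy_loc))
# toy_perturbed == [[True, True, False, True]]: only feature 2 is masked out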
def prepare_permutation(batch,
                        tgt_vocab_size,
                        order,
                        dataset,
                        policy_gradient,
                        decoder=None):
    """Transform a batch dictionary into a dataclass standard format
    for the transformer to process

    Arguments:

    batch: dict of tf.Tensors
        a dictionary that contains tensors from a tfrecord dataset;
        this function assumes region-features are used
    tgt_vocab_size: tf.Tensor
        the number of words in the target vocabulary of the model; used to
        calculate labels for the language model logits
    order: str or callable
        the autoregressive ordering used to train Transformer-InDIGO; either a
        fixed order such as l2r, r2l, rare, common, test, or sao, or a
        permutation transformer (tf.keras.Model) that produces soft orderings
    dataset: str
        type of dataset (captioning or wmt)
    policy_gradient: str
        whether to use policy gradient for training
        choices:
            none: no policy gradient
            with_bvn: use policy gradient with probabilities of
                hard permutations based on the Birkhoff-von Neumann decomposition
                of the soft permutation
            without_bvn: after applying the Hungarian algorithm to the soft
                permutation to obtain hard permutations, the probabilities of the
                hard permutations are proportional to the Gumbel-Matching
                distribution, i.e. exp(<X, P>_F); see https://arxiv.org/abs/1802.08665
    decoder: tf.keras.Model or None
        an autoregressive decoder used to run adaptive search when order
        is 'sao'; defaults to None

    Returns:

    inputs: TransformerInput
        the input to be passed into a transformer model, with the attributes
        needed to also compute the loss function
    permu_inputs:
        the batch prepared for the permutation transformer; only populated
        when policy_gradient is 'without_bvn', and None otherwise"""

    # process the dataset batch dictionary into the standard
    # model input format

    if dataset == 'captioning':
        words = batch['words']
        mask = batch['token_indicators']
        prepare_batch_for_lm = prepare_batch_for_lm_captioning
        prepare_batch_for_pt = prepare_batch_for_pt_captioning
    elif dataset in ['wmt', 'django', 'gigaword']:
        words = batch['decoder_words']
        mask = batch['decoder_token_indicators']
        prepare_batch_for_lm = prepare_batch_for_lm_wmt
        prepare_batch_for_pt = prepare_batch_for_pt_wmt

    inputs = prepare_batch_for_lm(tf.constant(1), batch)
    permu_inputs = None
    # the generation order is one of the fixed, predefined orderings
    if order in ['r2l', 'l2r', 'rare', 'common', 'test']:
        inputs[5] = get_permutation(mask, words, tf.constant(order))

    # pass the training example through the permutation transformer
    # to obtain a doubly stochastic matrix
    if isinstance(order, tf.keras.Model):  # corresponds to soft orderings
        if policy_gradient != 'without_bvn':
            inputs[5] = order(prepare_batch_for_pt(tf.constant(True),
                                                   tf.constant(1), batch), training=True)
        else:
            permu_inputs = prepare_batch_for_pt(tf.constant(True),
                                                tf.constant(1), batch)
            inputs[5], activations, kl, log_nom, log_denom = \
                order(permu_inputs, training=True)
            permu_inputs[-6] = activations
            permu_inputs[-5] = kl
            permu_inputs[-4] = log_nom - log_denom

    # when using a searched adaptive order (sao), run adaptive search with
    # the decoder to obtain a hard permutation for this batch
    if order == 'sao' and decoder is not None:
        cap, logp, rel_pos = adaptive_search(
            inputs, decoder, dataset,
            beam_size=8, max_iterations=200, return_rel_pos=True)
        pos = tf.argmax(rel_pos, axis=-1, output_type=tf.int32) - 1
        pos = tf.reduce_sum(tf.nn.relu(pos), axis=2)
        pos = tf.one_hot(pos, tf.shape(pos)[2], dtype=tf.float32)
        ind = tf.random.uniform([tf.shape(pos)[0], 1], maxval=7, dtype=tf.int32)
        # todo: make sure this is not transposed
        inputs[5] = tf.squeeze(tf.gather(pos, ind, batch_dims=1), 1)

    if policy_gradient == 'with_bvn':
        raise NotImplementedError
    elif policy_gradient == 'without_bvn':
        inputs[5] = tf.stop_gradient(inputs[5])

    # convert the permutation to absolute and relative positions
    inputs[6] = inputs[5][:, :-1, :-1]
    inputs[7] = permutation_to_relative(inputs[5])

    # convert the permutation to label distributions
    # also records the partial absolute position at each decoding time step
    hard_pointer_labels, inputs[10] = permutation_to_pointer(inputs[5][:, tf.newaxis, :, :])
    inputs[8] = tf.squeeze(hard_pointer_labels, axis=1)
    inputs[9] = tf.matmul(inputs[5][
                          :, 1:, 1:], tf.one_hot(inputs[4], tf.cast(tgt_vocab_size, tf.int32)))

    return inputs, permu_inputs
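
# The snippet below is an illustrative sketch (not part of the original
# example) of the beam-selection step in the 'sao' branch above: tf.gather
# with batch_dims=1 picks one searched permutation per batch element using a
# randomly sampled beam index. The shapes and maxval are made-up toy values
# mirroring the call above.
import tensorflow as tf

toy_pos = tf.random.normal([2, 8, 5, 5])    # [batch, beam, len, len] candidate orders
toy_ind = tf.random.uniform([tf.shape(toy_pos)[0], 1], maxval=7, dtype=tf.int32)
toy_choice = tf.squeeze(tf.gather(toy_pos, toy_ind, batch_dims=1), 1)
# toy_choice has shape [2, 5, 5]: one searched permutation per batch element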