def neural_gpu_body(inputs, hparams, name=None):
  """Core Neural GPU: fold a conv-GRU stack over the input's step axis.

  Args:
    inputs: 4-D tensor; axis 1 is the step axis that is folded over
      (assumed (batch, steps, ..., ...) — TODO confirm with callers).
    hparams: hyperparameters providing `dropout`, `num_hidden_layers`,
      `kernel_height`, `kernel_width` and `hidden_size`.
    name: optional variable-scope name; defaults to "neural_gpu".

  Returns:
    The final state tensor produced after folding over all steps.
  """
  with tf.variable_scope(name, "neural_gpu"):

    def step(prev_state, step_input):
      """Run the conv-GRU stack; keep prev_state where the input is padding."""
      updated = tf.nn.dropout(prev_state, 1.0 - hparams.dropout)
      kernel = (hparams.kernel_height, hparams.kernel_width)
      for layer_idx in range(hparams.num_hidden_layers):
        updated = common_layers.conv_gru(
            updated, kernel, hparams.hidden_size, name="cgru_%d" % layer_idx)
      # The modality zeroes out padded positions, so a near-zero absolute sum
      # over axes [1, 2] identifies a padding step.
      is_padding = tf.less(
          tf.reduce_sum(tf.abs(step_input), axis=[1, 2]), 0.00001)
      # Padding steps pass the previous state through unchanged.
      return tf.where(is_padding, prev_state, updated)

    # Move the step axis to the front and thread the state through tf.foldl.
    return tf.foldl(
        step,
        tf.transpose(inputs, [1, 0, 2, 3]),
        initializer=inputs,
        parallel_iterations=1,
        swap_memory=True)
# Example #2 (score: 0)
def elements_model(elements_texts_enc, feature_map, output_size, elements_mask,
                   ref_enc, flags):
    """The part of the model that processes the elements text and boxes.

    This assumes that the text has already been preprocessed with the
    text_model. Even if you are only using the elements and not the referring
    expression, you should probably use the ref_elements_model since that also
    handles preprocessing with the text_model.

    Args:
      elements_texts_enc: The elements text encoded by the text_model. Size:
        [batch_size * elements_per_query, text_embed_size]
      feature_map: Features used by the model.
      output_size: Desired output size of the encoding. Format: [length,
        width, depth]
      elements_mask: Mask for what elements items exist in the input.
      ref_enc: The referring expression encoded by the text_model.
        [batch_size, text_embed_size]
      flags: The input Flags.

    Returns:
      A pair (elements_enc, elements_enc_pre_atten): the final encoding of
      the elements data, and the encoding captured just before the optional
      attention re-weighting was applied.
    """

    with tf.variable_scope('elements_model'):
        # Depth of each element's embedding, taken from the requested output.
        elements_item_size = output_size[2]

        if flags.use_elements_boxes:
            elements_boxes = tf.identity(feature_map[ELEMENTS_BOX_ID],
                                         ELEMENTS_BOX_ID)
            # Keep only the rows where elements_mask is True, so downstream
            # layers only see slots that actually contain elements.
            flat_elements_boxes = tf.boolean_mask(elements_boxes,
                                                  elements_mask)
        else:
            elements_boxes = None
            flat_elements_boxes = None

        if ref_enc is not None:
            # Presumably tiles ref_enc to one row per masked element (per the
            # helper's name) — TODO confirm against tile_ref_enc_to_elements.
            ref_enc_tile = tile_ref_enc_to_elements(ref_enc, elements_mask)

        elements_ref_match_enc = None
        if flags.use_elements_ref_match:
            elements_ref_match = tf.identity(
                feature_map[ELEMENTS_REF_MATCH_ID], ELEMENTS_REF_MATCH_ID)
            tf.summary.text('elements_ref_match', elements_ref_match)
            flat_elements_ref_match = tf.boolean_mask(elements_ref_match,
                                                      elements_mask)

            # Encode the ref-match text with a (possibly pretrained) text model.
            elements_ref_match_enc = text_model(
                flat_elements_ref_match,
                flags.pretrained_elements_ref_match_model)

        # For combining the element with the referring expression.
        if flags.merge_ref_elements_method == 'combine' and (ref_enc
                                                             is not None):
            elements_enc = tf.concat(
                filter_none([
                    elements_texts_enc, flat_elements_boxes, ref_enc_tile,
                    elements_ref_match_enc
                ]), 1)
            elements_enc = tf.layers.dense(elements_enc,
                                           elements_item_size * 2, tf.nn.relu)
        else:
            # Paper results: same concat but without the tiled ref encoding.
            elements_enc = tf.concat(
                filter_none([
                    elements_texts_enc, flat_elements_boxes,
                    elements_ref_match_enc
                ]), 1)
            elements_enc = tf.layers.dense(elements_enc, elements_item_size,
                                           tf.nn.relu)

        neighbor_embed = None
        if flags.use_elements_neighbors:
            neighbor_embed = calc_neighbor_embed(
                feature_map[ELEMENTS_NEIGHBORS_ID], elements_enc,
                elements_mask)

        # Fold the (optional) neighbor embedding back into each element.
        elements_enc = tf.concat(filter_none([elements_enc, neighbor_embed]),
                                 1)

        elements_enc = tf.layers.dense(elements_enc, elements_item_size,
                                       tf.nn.relu)

        # Saved here — before the dropout/dense stack below — for use as the
        # attention input later on.
        attend_in = elements_enc

        # "DNN": two dense layers with dropout; the last has no activation.
        elements_enc = tf.nn.dropout(elements_enc, flags.elements_keep_prob)
        elements_enc = tf.layers.dense(elements_enc, elements_item_size,
                                       tf.nn.relu)
        elements_enc = tf.nn.dropout(elements_enc, flags.elements_keep_prob)
        elements_enc = tf.layers.dense(elements_enc, elements_item_size)

        # Second return value: encoding before any attention re-weighting.
        elements_enc_pre_atten = elements_enc

        if 'Atten' in flags.merge_ref_elements_method and (ref_enc
                                                           is not None):
            with tf.variable_scope('attention'):
                if elements_texts_enc is None:
                    # Prepad with 0s so the box embedding won't overlap with the ref_enc.
                    single_dot_concat = tf.zeros([
                        tf.shape(flat_elements_boxes)[0],
                        ref_enc.get_shape().as_list()[1]
                    ])
                else:
                    single_dot_concat = elements_texts_enc
                single_dot_in = tf.concat(
                    filter_none([
                        single_dot_concat,
                        flat_elements_boxes,
                        neighbor_embed,
                        elements_ref_match_enc,
                    ]), 1)
                # Append a constant-1 column to each row (bias-style term).
                single_dot_in = tf.concat(
                    [single_dot_in,
                     tf.ones([tf.shape(single_dot_in)[0], 1])], 1)

                attention_mask = attention(ref_enc, attend_in, single_dot_in,
                                           elements_mask, True,
                                           flags.merge_ref_elements_method,
                                           flags)

                attention_mask = tf.expand_dims(attention_mask, 1)

                # Re-weight each element's encoding by its attention score.
                elements_enc *= attention_mask

        # Projects the element embeddings into a 2d feature map.
        if flags.elements_proj_mode != 'tile':
            with tf.variable_scope('elements_proj'):
                # Projects the elements text onto the image feature map
                # on the corresponding bounding boxes.

                assert_op = tf.Assert(tf.equal(
                    output_size[0], output_size[1]), [
                        'Assumes height and width are the same.',
                        feature_map[ELEMENTS_BOX_ID]
                    ])
                with tf.control_dependencies([assert_op]):
                    if flags.proj_elements_memop:
                        # Iterate through all bounding boxes and embeddings to create
                        # embedded bounding boxes and sum into the result iteratively
                        # (lower peak memory than materializing all at once).
                        elements_enc = undo_mask(elements_enc, elements_mask)

                        # Pack [embedding | box] per element and move the
                        # element axis first so tf.foldl steps over elements.
                        fold_elms = tf.transpose(
                            tf.concat([elements_enc, elements_boxes], 2),
                            [1, 0, 2])

                        initializer = tf.zeros([tf.shape(elements_mask)[0]] +
                                               output_size)

                        def fold_fn(total, fold_elm):
                            # Split the packed row back into embedding / box.
                            elements_enc_boxes = tf.split(
                                fold_elm, [
                                    tf.shape(elements_enc)[2],
                                    tf.shape(elements_boxes)[2]
                                ], 1)
                            return total + get_filled_rect(
                                elements_enc_boxes[1], elements_enc_boxes[0],
                                output_size[0], flags.elements_proj_mode)

                        elements_enc = tf.foldl(fold_fn,
                                                fold_elms,
                                                initializer=initializer,
                                                swap_memory=True,
                                                parallel_iterations=2)

                    else:
                        # Create embeddings of all bounding boxes, then reduce-sum.
                        elements_enc = get_filled_rect(
                            flat_elements_boxes, elements_enc, output_size[0],
                            flags.elements_proj_mode)

                        elements_enc = undo_mask(elements_enc, elements_mask)

                        elements_enc = tf.reduce_sum(elements_enc, axis=1)

                # Turn sum into average.
                mask_sum = tf.cast(
                    tf.reduce_sum(tf.cast(elements_mask, tf.uint8), 1),
                    tf.float32)
                mask_sum = tf.reshape(mask_sum, [-1, 1, 1, 1])
                # Guard against division by zero when a query has no elements.
                mask_sum = tf.where(tf.equal(mask_sum, 0),
                                    tf.ones_like(mask_sum), mask_sum)
                elements_enc /= mask_sum
                tf.summary.histogram('elements_enc', elements_enc)

                # Channel-mean image summary for visual debugging.
                elements_enc_for_disp = tf.reduce_mean(elements_enc,
                                                       3,
                                                       keepdims=True)
                tf.summary.image('elements_enc_for_disp',
                                 elements_enc_for_disp, 4)
        else:
            # Undo the mask for feature mapping.
            sequence_elements_enc = undo_mask(elements_enc, elements_mask)

            # Average over the element axis to get one vector per query.
            elements_enc = tf.reduce_mean(sequence_elements_enc, axis=1)
            tf.summary.histogram('elements_enc', elements_enc)

            if flags.elements_3d_output:
                # Broadcast the pooled vector across the spatial output grid.
                elements_enc = tile_to_image(elements_enc, output_size)

        if flags.elements_3d_output:
            elements_enc.set_shape(
                [None, output_size[0], output_size[1], elements_item_size])

        # Last CNN layer of elements model.
        if flags.elements_3d_output and flags.elements_cnn:
            elements_enc = tf.layers.conv2d(elements_enc,
                                            elements_enc.shape[3],
                                            3,
                                            padding='SAME',
                                            activation=tf.nn.relu,
                                            strides=1)
            elements_enc = tf.nn.dropout(elements_enc,
                                         flags.elements_keep_prob)
            elements_enc = tf.layers.conv2d(elements_enc,
                                            elements_enc.shape[3],
                                            3,
                                            padding='SAME',
                                            activation=None,
                                            strides=1)

        return elements_enc, elements_enc_pre_atten