def neural_gpu_body(inputs, hparams, name=None):
  """The core Neural GPU."""
  with tf.variable_scope(name, "neural_gpu"):

    def step(state, inp):  # pylint: disable=missing-docstring
      # One timestep: dropout on the carried state, then a stack of conv-GRUs.
      hidden = tf.nn.dropout(state, 1.0 - hparams.dropout)
      for idx in range(hparams.num_hidden_layers):
        hidden = common_layers.conv_gru(
            hidden, (hparams.kernel_height, hparams.kernel_width),
            hparams.hidden_size,
            name="cgru_%d" % idx)
      # Padding input is zeroed-out in the modality, we check this by summing.
      is_padding = tf.less(tf.reduce_sum(tf.abs(inp), axis=[1, 2]), 0.00001)
      # Where the input step is padding, carry the old state through (no-op).
      return tf.where(is_padding, state, hidden)

    # Fold step() along the time axis (moved to the front by the transpose);
    # the whole input tensor serves as the initial state.
    return tf.foldl(
        step,
        tf.transpose(inputs, [1, 0, 2, 3]),
        initializer=inputs,
        parallel_iterations=1,
        swap_memory=True)
def elements_model(elements_texts_enc, feature_map, output_size, elements_mask,
                   ref_enc, flags):
  """The part of the model that processes the elements text and boxes.

  This assumes that the text has already been preprocessed with the text_model.
  Even if you are only using the elements and not the referring expression, you
  should probably use the ref_elements_model since that also handles
  preprocessing with the text_model.

  Args:
    elements_texts_enc: The elements text encoded by the text_model. Size:
      [batch_size * elements_per_query, text_embed_size]
    feature_map: Features used by the model.
    output_size: Desired output size of the encoding. Format: [length, width,
      depth]
    elements_mask: Mask for what elements items exist in the input.
    ref_enc: The referring expression encoded by the text_model. [batch_size,
      text_embed_size]
    flags: The input Flags.

  Returns:
    A tuple (elements_enc, elements_enc_pre_atten): the final encoding of the
    elements data, and the encoding taken just before the attention scaling.
  """
  with tf.variable_scope('elements_model'):
    # Depth of the per-element embedding produced below.
    elements_item_size = output_size[2]

    if flags.use_elements_boxes:
      elements_boxes = tf.identity(feature_map[ELEMENTS_BOX_ID],
                                   ELEMENTS_BOX_ID)
      # Drop padded slots so downstream dense layers only see real elements.
      flat_elements_boxes = tf.boolean_mask(elements_boxes, elements_mask)
    else:
      elements_boxes = None
      flat_elements_boxes = None

    if ref_enc is not None:
      # One copy of the referring-expression encoding per (unmasked) element.
      ref_enc_tile = tile_ref_enc_to_elements(ref_enc, elements_mask)

    elements_ref_match_enc = None
    if flags.use_elements_ref_match:
      elements_ref_match = tf.identity(feature_map[ELEMENTS_REF_MATCH_ID],
                                       ELEMENTS_REF_MATCH_ID)
      tf.summary.text('elements_ref_match', elements_ref_match)
      flat_elements_ref_match = tf.boolean_mask(elements_ref_match,
                                                elements_mask)
      elements_ref_match_enc = text_model(
          flat_elements_ref_match, flags.pretrained_elements_ref_match_model)

    # Combine the element features with the referring expression.
    if flags.merge_ref_elements_method == 'combine' and (ref_enc is not None):
      elements_enc = tf.concat(
          filter_none([
              elements_texts_enc, flat_elements_boxes, ref_enc_tile,
              elements_ref_match_enc
          ]), 1)
      elements_enc = tf.layers.dense(elements_enc, elements_item_size * 2,
                                     tf.nn.relu)
    else:  # Paper results
      elements_enc = tf.concat(
          filter_none([
              elements_texts_enc, flat_elements_boxes, elements_ref_match_enc
          ]), 1)
      elements_enc = tf.layers.dense(elements_enc, elements_item_size,
                                     tf.nn.relu)

    neighbor_embed = None
    if flags.use_elements_neighbors:
      # Mix an embedding of each element's neighbors into its encoding.
      neighbor_embed = calc_neighbor_embed(feature_map[ELEMENTS_NEIGHBORS_ID],
                                           elements_enc, elements_mask)
      elements_enc = tf.concat(filter_none([elements_enc, neighbor_embed]), 1)
      elements_enc = tf.layers.dense(elements_enc, elements_item_size,
                                     tf.nn.relu)

    # Saved here (pre-DNN) and fed to attention() below as the attended input.
    attend_in = elements_enc

    # "DNN"
    elements_enc = tf.nn.dropout(elements_enc, flags.elements_keep_prob)
    elements_enc = tf.layers.dense(elements_enc, elements_item_size, tf.nn.relu)
    elements_enc = tf.nn.dropout(elements_enc, flags.elements_keep_prob)
    elements_enc = tf.layers.dense(elements_enc, elements_item_size)
    elements_enc_pre_atten = elements_enc

    if 'Atten' in flags.merge_ref_elements_method and (ref_enc is not None):
      with tf.variable_scope('attention'):
        if elements_texts_enc is None:
          # Prepad with 0s so the box embedding won't overlap with the ref_enc.
          single_dot_concat = tf.zeros([
              tf.shape(flat_elements_boxes)[0],
              ref_enc.get_shape().as_list()[1]
          ])
        else:
          single_dot_concat = elements_texts_enc
        single_dot_in = tf.concat(
            filter_none([
                single_dot_concat,
                flat_elements_boxes,
                neighbor_embed,
                elements_ref_match_enc,
            ]), 1)
        # Append a constant-1 column (presumably a bias input to the dot
        # product — confirm against attention()).
        single_dot_in = tf.concat(
            [single_dot_in, tf.ones([tf.shape(single_dot_in)[0], 1])], 1)
        attention_mask = attention(ref_enc, attend_in, single_dot_in,
                                   elements_mask, True,
                                   flags.merge_ref_elements_method, flags)
        attention_mask = tf.expand_dims(attention_mask, 1)
        # Scale each element's embedding by its attention output.
        elements_enc *= attention_mask

    # Projects the element embeddings into a 2d feature map.
    if flags.elements_proj_mode != 'tile':
      with tf.variable_scope('elements_proj'):
        # Projects the elements text onto the image feature map on the
        # corresponding bounding boxes.
        assert_op = tf.Assert(
            tf.equal(output_size[0], output_size[1]), [
                'Assumes height and width are the same.',
                feature_map[ELEMENTS_BOX_ID]
            ])
        with tf.control_dependencies([assert_op]):
          if flags.proj_elements_memop:
            # Iterate through all bounding boxes and embeddings to create
            # embedded bounding boxes, summing into the result iteratively
            # (lower peak memory than materializing every box at once).
            elements_enc = undo_mask(elements_enc, elements_mask)
            fold_elms = tf.transpose(
                tf.concat([elements_enc, elements_boxes], 2), [1, 0, 2])
            initializer = tf.zeros([tf.shape(elements_mask)[0]] + output_size)

            def fold_fn(total, fold_elm):
              # Split a folded row back into its (embedding, box) parts.
              elements_enc_boxes = tf.split(fold_elm, [
                  tf.shape(elements_enc)[2],
                  tf.shape(elements_boxes)[2]
              ], 1)
              return total + get_filled_rect(elements_enc_boxes[1],
                                             elements_enc_boxes[0],
                                             output_size[0],
                                             flags.elements_proj_mode)

            elements_enc = tf.foldl(
                fold_fn,
                fold_elms,
                initializer=initializer,
                swap_memory=True,
                parallel_iterations=2)
          else:
            # Create embedding of all bounding boxes then reduce-sum them.
            elements_enc = get_filled_rect(flat_elements_boxes, elements_enc,
                                           output_size[0],
                                           flags.elements_proj_mode)
            elements_enc = undo_mask(elements_enc, elements_mask)
            elements_enc = tf.reduce_sum(elements_enc, axis=1)

          # Turn sum into average.
          mask_sum = tf.cast(
              tf.reduce_sum(tf.cast(elements_mask, tf.uint8), 1), tf.float32)
          mask_sum = tf.reshape(mask_sum, [-1, 1, 1, 1])
          # Guard against division by zero for queries with no elements.
          mask_sum = tf.where(
              tf.equal(mask_sum, 0), tf.ones_like(mask_sum), mask_sum)
          elements_enc /= mask_sum
        tf.summary.histogram('elements_enc', elements_enc)
        elements_enc_for_disp = tf.reduce_mean(elements_enc, 3, keepdims=True)
        tf.summary.image('elements_enc_for_disp', elements_enc_for_disp, 4)
    else:
      # 'tile' mode: undo the mask and average over the element axis instead
      # of painting boxes into a feature map.
      sequence_elements_enc = undo_mask(elements_enc, elements_mask)
      elements_enc = tf.reduce_mean(sequence_elements_enc, axis=1)
      tf.summary.histogram('elements_enc', elements_enc)
      if flags.elements_3d_output:
        elements_enc = tile_to_image(elements_enc, output_size)

    if flags.elements_3d_output:
      elements_enc.set_shape(
          [None, output_size[0], output_size[1], elements_item_size])

    # Last CNN layer of elements model.
    if flags.elements_3d_output and flags.elements_cnn:
      elements_enc = tf.layers.conv2d(
          elements_enc,
          elements_enc.shape[3],
          3,
          padding='SAME',
          activation=tf.nn.relu,
          strides=1)
      elements_enc = tf.nn.dropout(elements_enc, flags.elements_keep_prob)
      elements_enc = tf.layers.conv2d(
          elements_enc,
          elements_enc.shape[3],
          3,
          padding='SAME',
          activation=None,
          strides=1)
  return elements_enc, elements_enc_pre_atten