def _repeat_molecules(self,
                      molecules: tf.Tensor,
                      char_seq_length: tf.Tensor,
                      molecule_seq_length: tf.Tensor) -> tf.Tensor:
    """Stretches the molecule sequence along axis -2 up to character length.

    Position 0 of `molecules` is dropped and every remaining position is
    repeated `self.config.downsampling_rate` times. The last molecule is then
    repeated a further `char_seq_length % rate + rate` times and appended.

    Args:
      molecules: Tensor indexed as [batch, molecule position, hidden].
      char_seq_length: Length of the character sequence being matched.
      molecule_seq_length: Unused; kept only so the signature documents the
        contract.

    Returns:
      The concatenation, along axis -2, of the repeated molecules and the
      repeated last molecule.
    """
    # Part of the call contract only; the value is never read.
    del molecule_seq_length

    downsampling_rate = self.config.downsampling_rate

    # [batch_size, almost_char_seq_len, molecule_hidden_size]: enough elements
    # for any `char_seq_length` that is a multiple of the downsampling rate.
    tail_of_sequence = molecules[:, 1:, :]
    expanded = tf.repeat(tail_of_sequence, repeats=downsampling_rate, axis=-2)

    # Cover the remainder of the floor division by repeating the final
    # molecule. The extra `downsampling_rate` repeats compensate for the
    # position truncated above.
    final_molecule = molecules[:, -1:, :]
    leftover = tf.floormod(char_seq_length, downsampling_rate)
    padding = tf.repeat(
        final_molecule, repeats=leftover + downsampling_rate, axis=-2)

    # [batch_size, char_seq_len, molecule_hidden_size]
    return tf.concat([expanded, padding], axis=-2)
def _get_spans(
    seq_length,
    batch_size,
    max_span_length,
):
    """Builds every (start, end) span for offsets 0..max_span_length - 1.

    Args:
      seq_length: Forwarded to `_get_indexes`.
      batch_size: Batch dimension, used for the index grid and the reshape.
      max_span_length: Number of offsets to enumerate (end = start + offset).

    Returns:
      A tuple `(spans, num_spans, flat_spans_2d)`:
        spans: <int32>[batch_size, num_spans, 2] start/end pairs.
        num_spans: Size of axis 1 of `spans`.
        flat_spans_2d: <int32>[batch_size, num_spans*2, 2] pairs of
          (batch index, span endpoint).
    """
    positions = _get_indexes(seq_length, batch_size)

    def _spans_for(offset):
        # Both endpoints go through `_get_valid_indexes` with the same offset.
        shifted = positions + offset
        starts = _get_valid_indexes(positions, offset)
        ends = _get_valid_indexes(shifted, offset)
        return tf.stack([starts, ends], axis=-1)

    per_offset = [_spans_for(offset) for offset in range(max_span_length)]

    # <int32>[batch_size, num_spans, 2]
    spans = tf.concat(per_offset, axis=1)
    num_spans = modeling.get_shape_list(spans, expected_rank=3)[1]

    # <int32>[batch_size, num_spans*2]: each row holds its own batch index.
    batch_column = tf.expand_dims(tf.range(0, batch_size), axis=1)
    batch_ids = tf.repeat(batch_column, repeats=num_spans * 2, axis=1)
    endpoints = tf.reshape(spans, shape=(batch_size, num_spans * 2))

    # <int32>[batch_size, num_spans*2, 2]
    flat_spans_2d = tf.stack([batch_ids, endpoints], axis=2)
    return spans, num_spans, flat_spans_2d
def _get_relative_position_embeddings(
    full_position_embeddings,
    token_type_ids,
    token_type_vocab_size,
    seq_length,
    batch_size,
    max_position_embeddings,
):
    """Looks up position embeddings whose index restarts at every cell.

    A cell is the product of the column index (`token_type_ids[1]`) and the
    row index (`token_type_ids[2]`). Each token's position is measured from
    the first position of its cell and capped at
    `max_position_embeddings - 1`.

    Args:
      full_position_embeddings: Embedding table to look up from.
      token_type_ids: Indexable collection; entries 1 and 2 are used as the
        column and row ids.
      token_type_vocab_size: Indexable collection; entries 1 and 2 give the
        number of segments for the column and row index maps.
      seq_length: Sequence length used to build the position range.
      batch_size: Number of times the position row is repeated.
      max_position_embeddings: Upper bound (exclusive) for the looked-up index.

    Returns:
      The embeddings selected by the capped relative positions.
    """
    column_map = segmented_tensor.IndexMap(
        token_type_ids[1], token_type_vocab_size[1], batch_dims=1)
    row_map = segmented_tensor.IndexMap(
        token_type_ids[2], token_type_vocab_size[2], batch_dims=1)
    cell_map = segmented_tensor.ProductIndexMap(column_map, row_map)

    position = tf.expand_dims(tf.range(seq_length), axis=0)
    logging.info("position: %s", position)
    batched_position = tf.repeat(position, repeats=batch_size, axis=0)
    logging.info("batched_position: %s", batched_position)
    logging.info("token_type_ids: %s", token_type_ids[1])

    # Smallest absolute position of every cell, then spread back per token.
    cell_minimum = segmented_tensor.reduce_min(batched_position, cell_map)[0]
    cell_start = segmented_tensor.gather(cell_minimum, cell_map)

    last_valid_index = max_position_embeddings - 1
    relative_position = position - cell_start
    lookup_ids = tf.math.minimum(last_valid_index, relative_position)
    return tf.nn.embedding_lookup(full_position_embeddings, lookup_ids)
def _expansion_box_field_labels(self,
                                object_classes,
                                object_field,
                                copy_class_id=False):
    """Expands one per-object field following the class hierarchy.

    Args:
      object_classes: Int64 tensor with the class id of each element in
        `object_field`.
      object_field: Tensor to be expanded.
      copy_class_id: If true, the output holds class ids (column positions of
        the positive lookup-table entries plus `_LABEL_OFFSET`) instead of
        replicated values from `object_field`.

    Returns:
      A tensor with the result of expanding `object_field`.
    """
    # One lookup-table row per object, selected by its zero-based class index.
    ancestor_rows = tf.gather(
        self._ancestors_lut, object_classes - _LABEL_OFFSET, axis=0)

    if not copy_class_id:
        # Replicate each entry once per unit of its row sum.
        repeats_per_object = tf.reduce_sum(ancestor_rows, axis=1)
        return tf.repeat(object_field, repeats_per_object, axis=0)

    # Column 1 of `tf.where` is the position inside the row, i.e. the
    # zero-based class index of each positive entry.
    positive_positions = tf.where(ancestor_rows > 0)
    return positive_positions[:, 1] + _LABEL_OFFSET
def apply_matrix_to_single_image(image, matrix):
    """Resamples an image through a 3x3 matrix applied to pixel coordinates.

    Destination pixel coordinates (centred on 0, 0) are multiplied by `matrix`
    and cast to int32. They are then clipped to a fixed range and used to
    gather source pixels from `image`.

    Arguments:
        image [DIM, DIM, 3] -- input image; DIM is read from `image.shape[1]`
        matrix [3, 3] float -- transformation matrix

    Returns:
        [DIM, DIM, 3] -- transformed image
    """
    side = image.shape[1]
    half = side // 2
    # Unary minus binds tighter than `//`, so this is `(-side) // 2`.
    neg_half = -side // 2
    # 1 for an odd side length, 0 for an even one.
    parity = side % 2

    # Homogeneous destination coordinates, one column per output pixel.
    dest_x = tf.repeat(tf.range(half, neg_half, -1), side)
    dest_y = tf.tile(tf.range(neg_half, half), [side])
    dest_w = tf.ones([side * side], dtype='int32')
    dest = tf.stack([dest_x, dest_y, dest_w])

    # Map destination coordinates onto source coordinates.
    source = tf.linalg.matmul(matrix, tf.cast(dest, dtype='float32'))
    source = tf.cast(source, dtype='int32')

    # Clamp coordinates that fall outside the image boundaries.
    source = tf.clip_by_value(source, neg_half + parity + 1, half)

    # Convert centred coordinates into gather indices and fetch the pixels.
    gather_index = tf.stack([half - source[0, :], half - 1 + source[1, :]])
    pixels = tf.gather_nd(image, tf.transpose(gather_index))
    return tf.reshape(pixels, [side, side, 3])
def sqrt_fixed_full(x, config, is_training=True, causal=True):
    """Full attention matrix with sqrt decomposition.

    The sequence is cut into `num_seg` segments of `config.max_seg_len`
    positions. Each query attends (a) to the keys of its own segment and (b)
    to one pooled summary key/value per segment. Both sets of logits are
    normalised by a single softmax.

    Args:
      x: Input passed as query, key and value source to `attention.get_qkv`;
        `x.shape[0]` is used as the batch size in later reshapes.
      config: Object providing `model_size`, `num_heads`, `dense_use_bias`,
        `max_seq_len`, `max_seg_len`, `local_summary` and `dropatt`.
        `max_seq_len` must be a multiple of `max_seg_len`.
      is_training: Forwarded to `attention.dot_product_attention`; also gates
        dropout on the local attention weights.
      causal: If true, causal masks are added to both sets of logits.

    Returns:
      The output of `ops.trail_dense` applied to the merged attention result.
    """
    bsize = x.shape[0]
    query, key, value = attention.get_qkv(x, x, x, hidden_size=config.model_size,
                                          num_heads=config.num_heads,
                                          bias=config.dense_use_bias)
    head_dim = config.model_size // config.num_heads
    assert config.max_seq_len % config.max_seg_len == 0
    num_seg = config.max_seq_len // config.max_seg_len
    # [batch, num_seg, max_seg_len, num_heads, head_dim]
    cur_query = tf.reshape(query,
                           [-1, num_seg, config.max_seg_len,
                            config.num_heads, head_dim])
    with tf.variable_scope('pooling_query'):
        # Pooled over the within-segment axis; keepdims leaves axis 2 in place.
        merged_query = pooling_summary(cur_query, axis=2,
                                       local_summary=config.local_summary,
                                       keepdims=True)
    cur_key = tf.reshape(key, cur_query.shape)
    cur_val = tf.reshape(value, cur_query.shape)
    # Summary value per segment; the pooled axis is squeezed away afterwards.
    span_val = attention.dot_product_attention(merged_query,
                                               cur_key,
                                               cur_val,
                                               is_training=is_training,
                                               attn_axis=1,
                                               dropatt=config.dropatt)
    span_val = tf.squeeze(span_val, axis=2)
    with tf.variable_scope('pooling_key'):
        # Summary key per segment (pooled axis dropped via keepdims=False).
        span_key = pooling_summary(cur_key, axis=2,
                                   local_summary=config.local_summary,
                                   keepdims=False)
    # Within-segment scores: query position q against key position k of the
    # same segment s.
    local_logits = tf.einsum('bsqhd,bskhd->bsqhk', cur_query, cur_key)
    if causal:
        local_mask = get_causal_mask(cur_query, axis=2, is_strict=False)
        local_mask = tf.expand_dims(local_mask, axis=-2)
        local_logits += local_mask
    # Scores of every (unsegmented) query position against the per-segment
    # summary keys.
    prev_logits = tf.einsum('bqhd,bkhd->bqhk', query, span_key)
    if causal:
        prev_mask = get_causal_mask(cur_query, axis=1, is_strict=True)
        # Each of the `num_seg` rows is repeated `max_seg_len` times along
        # axis 0 -- presumably expanding a per-segment mask to per-position;
        # confirm against `get_causal_mask`.
        prev_mask = tf.repeat(prev_mask, [config.max_seg_len] * num_seg, axis=0)
        prev_logits += tf.expand_dims(prev_mask, axis=1)
    # Local logits are flattened back to one row per sequence position so both
    # kinds of logits can share one softmax over the last axis.
    joint_logits = tf.concat([tf.reshape(local_logits,
                                         [bsize, config.max_seq_len,
                                          config.num_heads, -1]),
                              prev_logits], axis=-1)
    attn_weights = attention.float32_softmax(joint_logits, axis=-1)
    local_att, prev_att = tf.split(attn_weights, [config.max_seg_len, num_seg],
                                   axis=-1)
    if is_training:
        # NOTE(review): dropout is applied to the local weights only; `prev_att`
        # is used as-is -- confirm this is intended.
        local_att = tf.nn.dropout(local_att, rate=config.dropatt)
    local_att = tf.reshape(local_att, [bsize, num_seg,
                                       config.max_seg_len,
                                       config.num_heads,
                                       config.max_seg_len])
    local_merged = tf.einsum('bsqhk,bskhd->bsqhd', local_att, cur_val)
    prev_merged = tf.einsum('bqhk,bkhd->bqhd', prev_att, span_val)
    # Sum both contributions after bringing the local part to the unsegmented
    # layout of `prev_merged`.
    joint_merged = prev_merged + tf.reshape(local_merged, prev_merged.shape)
    output = ops.trail_dense(joint_merged, config.model_size, begin_axis=-2)
    return output
def _else(detections, class_id, indices):
    """Appends the NMS survivors of one class to `detections`.

    Reads `boxes`, `scores`, `image_id`, `image_scale` and the NMS settings
    (`use_native_nms`, `max_boxes_to_draw`, `iou_threshold`,
    `min_score_thresh`, `soft_nms_sigma`) from the enclosing scope.

    Args:
      detections: Tensor of detection rows accumulated so far.
      class_id: Class whose boxes are processed; stored as `class_id + 1.0`.
      indices: Indices into `boxes` / `scores` that belong to this class.

    Returns:
      `detections` with one 7-column row appended per surviving box:
      [image_id, col0 * scale, col1 * scale, height * scale, width * scale,
      score, class_id + 1.0].
    """
    class_boxes = tf.gather(boxes, indices)
    class_scores = tf.gather(scores, indices)

    # Keep the top-scoring boxes of this class and suppress overlapping ones;
    # the survivors of every class are concatenated into the final output.
    if not use_native_nms:
        logging.info('Using customized nms.')
        class_scores = tf.expand_dims(class_scores, axis=1)
        candidates = tf.concat([class_boxes, class_scores], axis=1)
        top_detection_idx = nms_tf(candidates, iou_threshold)
        kept = tf.gather(candidates, top_detection_idx)
    else:
        logging.info('Using native nms.')
        top_detection_idx, class_scores = (
            tf.image.non_max_suppression_with_scores(
                class_boxes,
                class_scores,
                max_boxes_to_draw,
                iou_threshold=iou_threshold,
                score_threshold=min_score_thresh,
                soft_nms_sigma=soft_nms_sigma))
        class_scores = tf.expand_dims(class_scores, axis=1)
        class_boxes = tf.gather(class_boxes, top_detection_idx)
        kept = tf.concat([class_boxes, class_scores], axis=1)

    # Convert corner pairs into (corner, size) form and rescale.
    box_height = kept[:, 2] - kept[:, 0]
    box_width = kept[:, 3] - kept[:, 1]
    rescaled = tf.stack([
        kept[:, 0] * image_scale,
        kept[:, 1] * image_scale,
        box_height * image_scale,
        box_width * image_scale,
        kept[:, 4],
    ], axis=-1)

    # Surround the five value columns with the image id and the class id.
    image_column = tf.cast(
        tf.repeat(image_id, tf.size(top_detection_idx)), tf.float32)
    value_columns = tf.unstack(rescaled, 5, axis=1)
    class_column = tf.repeat(class_id + 1.0, tf.size(top_detection_idx))
    new_rows = tf.stack([image_column, *value_columns, class_column], axis=1)

    return tf.concat([detections, new_rows], axis=0)
def noise_like(shape, noise_fn=tf.random_normal, repeat=False,
               dtype=tf.float32):
    """Samples noise of the given shape, optionally shared across axis 0.

    Args:
      shape: Indexable target shape; `shape[0]` is the leading dimension.
      noise_fn: Callable accepting `shape` and `dtype` keyword arguments.
      repeat: If true, a single sample of shape `(1, *shape[1:])` is drawn and
        repeated `shape[0]` times along axis 0; otherwise the full shape is
        sampled directly.
      dtype: dtype forwarded to `noise_fn`.

    Returns:
      The noise tensor.
    """
    if repeat:
        single_sample = noise_fn(shape=(1, *shape[1:]), dtype=dtype)
        return tf.repeat(single_sample, repeats=shape[0], axis=0)
    return noise_fn(shape=shape, dtype=dtype)
def mu_tf(self, t, X, Y, Z):
    """Returns `A - inputs * B` for the dynamics given by `self.cartpole.f`.

    Shapes noted by the original author: `t` is M x 1, `X` is M x D, `Y` is
    M x 1, `Z` is M x D and the result is M x D. `t` and `Y` are not read.

    NOTE(review): the control signal is tiled to 4 columns with a hard-coded
    constant; presumably this equals D -- confirm.
    """
    drift, control_matrix = self.cartpole.f(X)
    # Only the element-wise product helper is needed here.
    _, elementwise_product, _ = self.getTFUtils()

    gain = tf.constant(1.0 / self.R)
    control = gain * self.BtXgradV(control_matrix, Z)

    # Flatten to one value per sample, then copy it across 4 columns so it can
    # be multiplied element-wise with `control_matrix`.
    control_column = tf.reshape(control, (-1, 1))
    control_tiled = tf.repeat(control_column, repeats=4, axis=1)

    return drift - elementwise_product(control_tiled, control_matrix)
def __call__(self, inps):
    """Computes a top-k gate and folds its mean into the sparse values.

    Args:
      inps: Input multiplied with `self.g_weight` to obtain gate scores.

    Returns:
      A 6-tuple `(rows, columns, values, row_indices, row_offsets,
      column_indices)`. All entries are the stored attributes, except that
      `values` has the mean of the normalised gate added to it.
    """
    gate_scores = tf.matmul(inps, self.g_weight)

    # Last column of the top-k values, used as the keep threshold.
    # NOTE(review): `threshold` has one entry per row of `gate_scores` and is
    # broadcast against the last axis below; that is a per-row threshold only
    # when there is a single row -- confirm the expected input shape.
    top_values = tf.math.top_k(gate_scores, self.top_blocks)[0]
    threshold = top_values[:, -1]

    keep = tf.cast((gate_scores >= threshold), tf.float32)
    gate = keep * gate_scores

    # One gate value per block entry, laid out as a (4 * hz, 2 * hz) matrix.
    gate = tf.repeat(gate, self.block_size**2, axis=-1)
    gate = tf.reshape(gate, (self.hz * 4, self.hz * 2))

    # Divide by the mean (sum / element count).
    gate_mean = tf.math.reduce_sum(gate) / tf.cast(tf.size(gate), tf.float32)
    gate = gate / gate_mean

    # The dense-to-sparse conversion cost is excluded; a dummy operation is
    # added to the graph in its place.
    tf.reduce_mean(gate)

    gated_values = self.values + tf.reduce_mean(gate)
    return (self._rows, self._columns, gated_values, self._row_indices,
            self._row_offsets, self._column_indices)
def _gather_and_repeat(tensors, name):
    """Reorders each tensor by its type's sort order and copies it per head.

    Reads `type_ids_argsorted` and `type_counts` from the enclosing scope.

    Args:
      tensors: List with one tensor per token type, or a single-element list
        that is reused for every type.
      name: Variable scope under which the ops are created.

    Returns:
      The per-type tensors stacked on axis 1, with entry i repeated
      `type_counts[i]` times along that axis.
    """
    if len(tensors) == 1:
        per_type_inputs = tensors * len(type_ids_argsorted)
    else:
        per_type_inputs = tensors

    with tf.variable_scope(name):
        # A `None` order leaves the tensor untouched.
        reordered = [
            tensor if order is None
            else tf.gather(tensor, order, batch_dims=1)
            for tensor, order in zip(per_type_inputs, type_ids_argsorted)
        ]
        stacked = tf.stack(reordered, axis=1)
        return tf.repeat(stacked, repeats=type_counts, axis=1)
def decompress_episode_representation(episode_representation):
    """Expands a run-length encoded episode into a stream of dataset IDs.

    Args:
      episode_representation: tensor of shape [None, 2]. Column 0 holds
        dataset IDs and column 1 the number of times each ID is repeated in
        the sequence.

    Returns:
      1D tensor, the decompressed sequence of dataset IDs.
    """
    # Pin the static shape so that unstacking yields exactly two columns.
    episode_representation.set_shape([None, 2])
    id_column, count_column = tf.unstack(episode_representation, axis=1)
    return tf.repeat(id_column, count_column)
def _else(detections, class_id):
    """Appends the NMS survivors of one class to `detections`.

    Reads `boxes`, `scores`, `indices`, `image_id`, `image_scale`,
    `use_native_nms` and `MAX_DETECTIONS_PER_IMAGE` from the enclosing scope.
    The IoU threshold is fixed at 0.5 in both NMS paths.

    Args:
      detections: Tensor of detection rows accumulated so far.
      class_id: Class being processed; stored as `class_id + 1.0`.

    Returns:
      `detections` with one 7-column row appended per surviving box:
      [image_id, col0 * scale, col1 * scale, height * scale, width * scale,
      score, class_id + 1.0].
    """
    class_boxes = tf.gather(boxes, indices)
    class_scores = tf.gather(scores, indices)

    # Rows of [box (4 values), score]; NMS is applied within the class and the
    # survivors of every class are concatenated into the final output.
    candidates = tf.concat(
        [tf.reshape(class_boxes, [-1, 4]), class_scores], axis=1)

    if not use_native_nms:
        top_detection_idx = nms_tf(candidates, 0.5)
    else:
        top_detection_idx = tf.image.non_max_suppression(
            candidates[:, :4],
            candidates[:, 4],
            MAX_DETECTIONS_PER_IMAGE,
            iou_threshold=0.5)
    kept = tf.gather(candidates, top_detection_idx)

    # Convert corner pairs into (corner, size) form and rescale.
    box_height = kept[:, 2] - kept[:, 0]
    box_width = kept[:, 3] - kept[:, 1]
    rescaled = tf.stack([
        kept[:, 0] * image_scale,
        kept[:, 1] * image_scale,
        box_height * image_scale,
        box_width * image_scale,
        kept[:, 4],
    ], axis=-1)

    # Surround the five value columns with the image id and the class id.
    image_column = tf.cast(
        tf.repeat(image_id, tf.size(top_detection_idx)), tf.float32)
    value_columns = tf.unstack(rescaled, 5, axis=1)
    class_column = tf.repeat(class_id + 1.0, tf.size(top_detection_idx))
    new_rows = tf.stack([image_column, *value_columns, class_column], axis=1)

    return tf.concat([detections, new_rows], axis=0)
def giou(boxlist1, boxlist2, scope=None):
    """Computes the pairwise generalized IOU between two boxlists.

    Every box of `boxlist1` is paired with every box of `boxlist2`.
    `ops.giou` is evaluated on the flattened pairs and the result is reshaped
    into a matrix.

    Args:
      boxlist1: BoxList holding N boxes
      boxlist2: BoxList holding M boxes
      scope: name scope.

    Returns:
      a tensor with shape [N, M] holding the value of `ops.giou` for each
      (boxlist1, boxlist2) pair.
    """
    with tf.name_scope(scope, 'PairwiseGIoU'):
        num_first = boxlist1.num_boxes()
        num_second = boxlist2.num_boxes()
        # Row i * M + j pairs box i of the first list with box j of the second.
        first_expanded = tf.repeat(boxlist1.get(), repeats=num_second, axis=0)
        second_expanded = tf.tile(boxlist2.get(), multiples=[num_first, 1])
        pairwise = ops.giou(first_expanded, second_expanded)
        return tf.reshape(pairwise, [num_first, num_second])
def upsample(self, output, stride, tgt_len):
    """Upsamples a hidden state by repeating positions along axis 1.

    Args:
      output: Hidden state; axis 1 is the sequence axis.
      stride: Repeat factor. A stride of 1 returns `output` unchanged.
      tgt_len: Target length, used to trim the result when
        `net_config.separate_cls` is set and `FLAGS.truncate_seq` is not.

    Returns:
      The upsampled hidden state. With `separate_cls`, position 0 is kept
      as-is and only the remaining positions are repeated.
    """
    if stride == 1:
        return output

    if not self.net_config.separate_cls:
        return tf.repeat(output, repeats=stride, axis=1)

    # Position 0 is split off so that it is not repeated.
    cls_output = output[:, :1]
    stretched = tf.repeat(output[:, 1:], repeats=stride, axis=1)

    if FLAGS.truncate_seq:
        # Pad `stride - 1` positions at the end of the sequence axis.
        stretched = tf.pad(stretched, [[0, 0], [0, stride - 1], [0, 0]])
    else:
        # Leave room for the position that is prepended below.
        stretched = stretched[:, :tgt_len - 1]

    return tf.concat([cls_output, stretched], axis=1)
def loss(model, cartpoleUtil, t_interior, X_interior, t_terminal, X_terminal):
    ''' Compute total loss for training.

    Builds two scalar loss terms: the mean squared residual of the differential
    operator at interior points, and the mean squared mismatch with the
    terminal condition `u`. Intermediate tensors are printed to stdout while
    the graph is built.

    Reads the module-level names `Q`, `R`, `D`, `snoise`, `u`, `getTFUtils`,
    `quadraticForm` and `BtXgradV`.

    Args:
        model:         DGM model object, called as model(t, X)
        cartpoleUtil:  object whose f(X) returns the pair (A, B) used for the
                       drift term
        t_interior:    sampled time points in the interior of the function's domain
        X_interior:    sampled space points in the interior of the function's domain
        t_terminal:    sampled time points at terminal point (vector of terminal times)
        X_terminal:    sampled space points at terminal time

    Returns:
        (L1, L3): the interior (PDE) loss and the terminal-condition loss.
    '''
    # Loss term #1: PDE
    # compute function value and derivatives at current sampled points
    # \frac{\partial u}{\partial t}(t, x) + \Delta u(t, x) - \lambda \| \nabla u(t, x) \|^2 = 0
    # => V_t + V_xx - lambda * L2_norm(V_x)^2
    # Only `multiply` and `rowsum` are used below.
    matmul, multiply, rowsum = getTFUtils()
    V = model(t_interior, X_interior)
    V_t = tf.gradients(V, t_interior)[0]
    print('V_t=%s' % V_t)
    # f = phi1 + phi2
    const = tf.constant
    print('X_interior=%s' % X_interior)
    # phi1: half the quadratic form of the state with weight Q.
    phi1 = tf.constant(0.5) * quadraticForm(X_interior, Q)
    A, B = cartpoleUtil.f(X_interior)
    print('A=%s' % A)
    print('B=%s' % B)
    V_x = tf.gradients(V, X_interior)[0]
    print('V_x=%s' % V_x)
    Bt_gradV = BtXgradV(B, V_x)
    print('Bt_gradV=%s' % Bt_gradV)
    # phi2: 0.5 * (B^T grad V)^2 / R.
    phi2 = const(0.5) * tf.square(Bt_gradV) / const(R * 1.0)
    print('phi1=%s' % phi1)
    print('phi2=%s' % phi2)
    f = phi1 + phi2
    # mu^T
    uinput = const(1.0 / R) * Bt_gradV
    # One control value per sample, copied across D columns to match B.
    inputs = tf.repeat(tf.reshape(uinput, (-1, 1)), repeats=D, axis=1)
    print('inputs=%s' % inputs)
    mu_t = A - multiply(inputs, B)
    # NOTE(review): tf.gradients differentiates the sum of its first argument,
    # so V_xx has the shape of X_interior rather than being a per-sample
    # Hessian -- confirm that tf.linalg.trace below is the intended reduction.
    V_xx = tf.gradients(V_x, X_interior)[0]
    print('V_t=%s' % V_t)
    print('f=%s' % f)
    print('mu_t=%s' % mu_t)
    print('V_x=%s' % V_x)
    print('snoise=%s' % snoise)
    print('V_xx=%s' % V_xx)
    # Row-wise dot product of the drift with the gradient.
    mul = rowsum(multiply(mu_t, V_x))
    diff_V = V_t + f + mul + 0.5 * (snoise**2) * tf.linalg.trace(V_xx)
    # compute average L2-norm of differential operator
    L1 = tf.reduce_mean(tf.square(diff_V))
    # Loss term #2: boundary condition
    # no boundary condition for this problem
    # Loss term #3: initial/terminal condition
    target_terminal = u(X_terminal)
    fitted_terminal = model(t_terminal, X_terminal)
    L3 = tf.reduce_mean(tf.square(fitted_terminal - target_terminal))
    return L1, L3
def tgn_memory(
    n_nodes: int,
    memory_size: int,
    time_embedding_size: int,
    node_ids: tf.Tensor,
    write_idx: tf.Tensor,
    write_mask: tf.Tensor,
    write_features: tf.Tensor,
    write_times: tf.Tensor,
) -> TgnMemory:
    """Create TGN memory read & update operations.

    A trainable memory for nodes in a temporal interaction graph. The memory
    state is computed using the latest interaction event that touched a node.
    The update is a GRU cell, taking as input the previous memory of both
    source and destination nodes for that edge, the edge feature vector and
    time difference from interaction to current time.

    Note that the GRU cell is computed lazily when the memory is read, rather
    than when it is stored, to support a single step of truncated
    backpropagation through time and obtain a gradient for GRU variables.

    Please see "Temporal Graph Network" (https://arxiv.org/abs/2006.10637) for
    full details.

    Arguments:

      n_nodes -- total number of slots in the memory

      memory_size -- size of stored state in the memory / GRU cell output size

      time_embedding_size -- size of the time encoding activation provided to
                             the GRU cell

      node_ids -- shape (n_read), (-1 <= ID < n_nodes), the memory locations to
                  be read

      write_idx -- shape (2, n_write), (0 <= idx < n_read), the (src, dst)
                   indices of edges, selecting nodes that should be written
                   with their updated memory state

      write_mask -- shape (2, n_write), boolean tensor for elements in
                    write_idx that should be written (true) or skipped (false),
                    such that each memory location is written at most once

      write_features -- shape (n_write, feature_size), input features to be
                        stored and used to compute the memory when it is next
                        accessed

      write_times -- shape (n_write), edge event times to be stored and used to
                     compute the memory when it is next accessed

    Returns:

      TgnMemory(
        output -- tensor of shape (n_read, memory_size), current memory for
                  node_ids

        last_update -- tensor of shape (n_read), last update of output

        updates -- tuple of operations to run to update the memory
      )
    """
    # Validate the argument shapes and recover n_write / feature_size from
    # them.
    assert_shape(node_ids, (None, ))
    _, n_write = assert_shape(write_idx, (2, None))
    assert_shape(write_mask, (2, n_write))
    _, feature_size = assert_shape(write_features, (n_write, None))
    assert_shape(write_times, (n_write, ))
    dtype = write_features.dtype

    # Declare memory
    # As an optimisation, we concatenate the 6 fields required by the memory
    # into 2 tensors, one consisting of ints, the other of floats.
    # This requires some extra code to slice and concat, but means we can use
    # 2 (dynamic) gather operations instead of 6.

    # Each row: [last_update, dt, neighbour]
    v_ints = tf.get_variable(
        "ints",
        shape=(1 + n_nodes, 3),
        dtype=tf.int32,
        trainable=False,
        initializer=tf.zeros_initializer(),
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, TGN_MEMORY_VARIABLES_KEY],
    )
    # Each row: [memory, features, direction]
    v_floats = tf.get_variable(
        "floats",
        shape=(1 + n_nodes, memory_size + feature_size + 2),
        dtype=dtype,
        trainable=False,
        initializer=tf.zeros_initializer(),
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, TGN_MEMORY_VARIABLES_KEY],
    )

    # Memory[0] is used for padding (node_ids == -1)
    safe_node_ids = 1 + node_ids

    # Read memory for node_ids
    node_ints = tf.gather(v_ints, safe_node_ids)
    node_last_update, node_dt, node_neighbour_idx = tf.unstack(node_ints,
                                                               axis=1)
    # The stored neighbour index already includes the +1 padding shift (it is
    # written from `indices` below), so it addresses v_floats directly.
    node_neighbour = tf.gather(v_floats[:, :memory_size], node_neighbour_idx)
    node_time_encoding = time_encoder(tf.cast(node_dt, tf.float32),
                                      time_embedding_size, dtype)

    # Split the float row back into its three fields.
    node_floats = tf.gather(v_floats, safe_node_ids)
    node_self = node_floats[:, :memory_size]
    node_features = node_floats[:, memory_size:memory_size + feature_size]
    node_direction = node_floats[:, memory_size + feature_size:]

    # The two direction flags select which of (self, neighbour) fills the first
    # and the second memory slot of the GRU input.
    node_memory = gru_cell(
        node_self,
        tf.concat(
            [
                node_direction[:, 0, tf.newaxis] * node_self +
                node_direction[:, 1, tf.newaxis] * node_neighbour,
                node_direction[:, 1, tf.newaxis] * node_self +
                node_direction[:, 0, tf.newaxis] * node_neighbour,
                node_features,
                node_time_encoding,
            ],
            axis=1,
        ),
    )

    # Write memory according to (write_idx, write_mask)
    flat_write_idx = tf.reshape(write_idx, (-1, ))
    indices = tf.gather(safe_node_ids, flat_write_idx)
    # Masked-out writes are redirected to row 0, the padding slot.
    masked_indices = indices * tf.cast(tf.reshape(write_mask, (-1, )),
                                       indices.dtype)
    # Every payload below is laid out as all src entries, then all dst entries.
    p_last_update = tf.reshape(tf.tile(write_times[tf.newaxis], (2, 1)),
                               (-1, ))
    p_dt = p_last_update - tf.gather(node_last_update, flat_write_idx)
    # Swap src and dst indices to get the neighbour index for each node
    p_neighbour = tf.roll(indices, n_write, 0)
    p_memory = tf.gather(node_memory, flat_write_idx)
    p_features = tf.tile(write_features, (2, 1))
    p_direction = tf.repeat(tf.eye(2, dtype=dtype), n_write,
                            0)  # src=[1, 0], dst=[0, 1]

    # There is already a data dependency, but just to be sure...
    with tf.control_dependencies([node_last_update, node_memory]):
        update_ints = v_ints.scatter_update(
            tf.IndexedSlices(
                tf.stack([p_last_update, p_dt, p_neighbour], axis=1),
                masked_indices))
        update_floats = v_floats.scatter_update(
            tf.IndexedSlices(
                tf.concat([p_memory, p_features, p_direction], axis=1),
                masked_indices))

    return TgnMemory(
        output=node_memory,
        last_update=node_last_update,
        updates=(update_ints, update_floats),
    )
def transformer_conv(
    n_output: int,
    n_heads: int,
    dropout: float,
    nodes: tf.Tensor,
    edge_idx: tf.Tensor,
    edges: tf.Tensor,
) -> tf.Tensor:
    """Graph Transformer attention layer, https://arxiv.org/abs/2009.03509.

    Follows the specification of TransformerConv in PyTorch Geometric. It
    always uses a "skip" projection from the inputs and a shared key/value
    projection for edges.

    Arguments:

      n_output -- output feature size

      n_heads -- number of attention heads (head size is n_output / n_heads,
                 so n_output must be divisible by n_heads)

      dropout -- rate for dropout on the post-softmax attention weights;
                 a falsy value disables it

      nodes -- shape (n_nodes, node_feature_size), input features per node

      edge_idx -- shape (2, n_edges), (0 <= edge_idx < n_nodes), the source
                  and destination of each edge, indexing into nodes

      edges -- shape (n_edges, edge_feature_size), input features per edge

    Returns:

      tensor of shape (n_nodes, n_output), node features after applying a
      graph transformer (attention) layer
    """
    assert n_output % n_heads == 0, \
        "graph transformer output size should be divisible by the number of heads"
    head_size = n_output // n_heads

    n_nodes, _ = assert_shape(nodes, (None, None))
    _, n_edges = assert_shape(edge_idx, (2, None))
    assert_shape(edges, (n_edges, None))

    with tf.variable_scope("skip"):
        skip = linear(nodes, n_output)
    with tf.variable_scope("edge_shared_kv"):
        edge_kv = linear(edges, n_output, use_bias=False)
    with tf.variable_scope("node_qkv"):
        node_qkv = linear(nodes, 3 * n_output)

    with tf.variable_scope("attention"):
        # Queries come from the destination node of each edge; keys and values
        # from the source node, offset by the shared edge projection.
        query = tf.gather(node_qkv[:, :n_output], edge_idx[1])
        source_kv = tf.gather(node_qkv[:, n_output:], edge_idx[0])
        source_kv = tf.reshape(source_kv, (n_edges, 2, n_output))
        key, value = tf.unstack(source_kv + edge_kv[:, tf.newaxis, :], axis=1)

        # Scaled dot product per head, normalised over the edges that share a
        # destination node.
        per_head_product = tf.reshape(query * key,
                                      (n_edges, n_heads, head_size))
        scores = tf.reduce_sum(per_head_product, -1) / (head_size**0.5)
        weights = index_softmax(scores, edge_idx[1], n_nodes)
        if dropout:
            weights = tf.nn.dropout(weights, rate=dropout)

        # Spread each head's weight over its `head_size` features, then sum
        # the weighted values into their destination nodes.
        weighted_values = tf.repeat(weights, head_size, axis=1) * value
        attention = tf.unsorted_segment_sum(weighted_values, edge_idx[1],
                                            n_nodes)

    return skip + attention
def create_bucketed_attention_layer(input_mask, input_header, bucket_size,
                                    header_size, token_type_ids,
                                    sort_after_projection):
    """Returns a drop-in replacement for attention_layer using sparse attention.

    Args:
      input_mask: int32<batch_size, seq_length> The values should be 1 or 0.
        The attention scores will effectively be set to -infinity for any
        positions in the mask that are 0, and will be unchanged for positions
        that are 1.
      input_header: bool<batch_size, seq_length> Attention will not be
        restricted to or from tokens where header is true. Positions whose
        `input_mask` is not 1 are never treated as header.
      bucket_size: int. Size of sections where self attention happens.
      header_size: Size of the first bucket that will attend to/from
        everything. If None is passed will use the same as `bucket_size`.
      token_type_ids: List<(int, bool, <int32>[batch_size, seq_length])>
        contains the number of heads for each token type, whether they are
        sorted, and the ids of each position. Attention is restricted between
        tokens with the same type id and this field is used to sort/bucket.
        ids must be non-negative.
      sort_after_projection: Sorting can happen on the layer input or after
        applying the projection to keys, queries and values. Depending on the
        accelerator, one option could be more convenient.

    Returns:
      Function with same signature as `attention_layer`.
      See `_bucketed_attention_layer`.
    """
    type_counts = [head_count for head_count, _, _ in token_type_ids]
    num_heads = sum(type_counts)

    # Ensure that the padding tokens get sorted last.
    additive_mask = _additive_mask(input_mask)
    type_sorted = [already_sorted for _, already_sorted, _ in token_type_ids]
    type_ids_masked = [
        raw_ids - additive_mask for _, _, raw_ids in token_type_ids
    ]
    # `None` marks types that are already sorted and need no gather.
    type_ids_argsorted = [
        None if already_sorted else tf.argsort(masked_ids, stable=True)
        for already_sorted, masked_ids in zip(type_sorted, type_ids_masked)
    ]

    def _gather_and_repeat(tensors, name):
        """Reorders each tensor by its type's order and copies it per head."""
        if len(tensors) == 1:
            per_type_inputs = tensors * len(type_ids_argsorted)
        else:
            per_type_inputs = tensors
        with tf.variable_scope(name):
            reordered = [
                tensor if order is None
                else tf.gather(tensor, order, batch_dims=1)
                for tensor, order in zip(per_type_inputs, type_ids_argsorted)
            ]
            stacked = tf.stack(reordered, axis=1)
            return tf.repeat(stacked, repeats=type_counts, axis=1)

    # <int32>[batch_size, num_types, seq_length]
    type_ids = tf.stack(type_ids_masked, axis=1)
    type_order = tf.argsort(type_ids, stable=True)
    type_rank = tf.argsort(type_order, stable=True)

    # For each head the correct order to sort the embeddings.
    # <int32>[batch_size, num_attention_heads, seq_length].
    type_order_repeated = tf.repeat(type_order, repeats=type_counts, axis=1)
    # For each head the inverse of the correct order to sort the embeddings.
    # <int32>[batch_size, num_attention_heads, seq_length].
    type_rank_repeated = tf.repeat(type_rank, repeats=type_counts, axis=1)

    # Padding positions never count as header.
    valid_header = input_header & tf.math.equal(input_mask, 1)

    if not sort_after_projection:
        type_ids_repeated = _gather_and_repeat(type_ids_masked,
                                               'type_id_gather')
        attend_always_repeated = _gather_and_repeat([valid_header],
                                                    'attend_always_gather')
    else:
        # No need to sort in this case since this will happen in the
        # BucketedTensor.
        type_ids_repeated = tf.repeat(type_ids, repeats=type_counts, axis=1)
        attend_always_repeated = tf.repeat(
            tf.expand_dims(valid_header, axis=1), repeats=num_heads, axis=1)

    first_bucket_size = bucket_size if header_size is None else header_size

    args = {
        'order_indices': type_order_repeated,
        'bucket_size': bucket_size,
        'header_size': first_bucket_size,
        'num_heads': num_heads,
        'sort_after_projection': sort_after_projection,
    }
    ids = _create_bucketed_tensor(type_ids_repeated, 'type_id', **args)
    attend_always = _create_bucketed_tensor(attend_always_repeated,
                                            'attend_always', **args)

    head_full_mask = _compute_bucketed_attention_mask('head', 'full', ids,
                                                      attend_always)
    tail_head_mask = _compute_bucketed_attention_mask('tail', 'head', ids,
                                                      attend_always)
    tail_window_mask = _compute_bucketed_attention_mask(
        'tail', 'window', ids, attend_always)

    return functools.partial(
        _bucketed_attention_layer,
        sort_after_projection=sort_after_projection,
        gather_and_repeat=_gather_and_repeat,
        bucketed_tensor_args=args,
        header_size=first_bucket_size,
        type_ranks=type_rank_repeated,
        head_full_mask=head_full_mask,
        tail_head_mask=tail_head_mask,
        tail_window_mask=tail_window_mask)
def make_fixed_block_side_inputs(
    input_mask: tf.Tensor,
    num_tokens_per_block: int,
    local_radius: int,
    relative_pos_max_distance: int,
    use_hard_g2l_mask: bool = False,
    use_hard_l2g_mask: bool = False,
    global_token_id: int = 1,
    name: Optional[Text] = None
) -> Tuple[GlobalLocalTransformerSideInputs, tf.Tensor]:
    """Utility for creating side inputs in a "fixed blocks" pattern.

    The "fixed blocks" experiments for NQ and OpenKP are implemented via
    example generation rather than using this function. It is included to
    show how side inputs can be generated given just a BERT-style
    `input_mask` feature. The corresponding global tokens are generated here
    too, so no global features are required as input.

    Args:
      input_mask: <int32>[batch_size, long_seq_len] Tensor of 1 and 0 values,
        with 1 for actual tokens and 0 for padding. This is the same format as
        original BERT. `long_seq_len` must be statically known.
      num_tokens_per_block: Positive integer number of long tokens to assign
        to each global token. For pre-training on the original BERT data
        (which was also used for ETC pre-training), the dataset implied a
        value of about 27, but values like 16 or 32 would also be reasonable.
      local_radius: How many tokens to the left/right for input tokens to
        locally self-attend to. For example, a value of 1 would allow each
        token to only attend to 1 token to the left and 1 token to the right.
      relative_pos_max_distance: Maximum distance to use for relative position
        representations. All larger distances will be clipped to this value.
        Use 0 to skip relative position representations entirely.
      use_hard_g2l_mask: If True, global tokens only attend to tokens of their
        corresponding block in the long input. If False, global tokens attend
        to all non-padding long tokens. False is the default setup.
      use_hard_l2g_mask: If True, long tokens only attend to the global token
        corresponding to their block. If False, long tokens attend to all the
        non-padding global tokens. False is the default setup.
      global_token_id: Integer id to use for global tokens. The default is
        `1`, which was the value used during ETC pre-training.
      name: A name for the operation (optional).

    Returns:
      A tuple with the following 2 elements:
        side_inputs: A `GlobalLocalTransformerSideInputs` object containing
          all side input tensors.
        global_token_ids: <int32>[batch_size, global_seq_len] Tensor of global
          token ids suitable to pass into `EtcModel`. All global tokens will
          use the same `global_token_id`, except for padding tokens.

    Raises:
      ValueError: If `num_tokens_per_block` is not positive or `long_seq_len`
        is not statically known.
    """
    if num_tokens_per_block <= 0:
        raise ValueError('`num_tokens_per_block` must be positive.')

    scope_name = name or 'make_fixed_block_side_inputs'
    with tf.name_scope(scope_name):
        input_mask = tf.convert_to_tensor(input_mask)

        batch_size = tensor_utils.get_shape_list(input_mask)[0]
        long_seq_len = input_mask.shape.as_list()[1]
        if long_seq_len is None:
            raise ValueError('`long_seq_len` must be statically known.')

        # Number of blocks, rounding up so a partial last block is counted.
        global_seq_len = (
            (long_seq_len + num_tokens_per_block - 1) // num_tokens_per_block)

        # [batch_size, global_seq_len, num_tokens_per_block]
        blocked_input_mask = tensor_utils.split_into_blocks(
            input_mask, block_len=num_tokens_per_block, axis=-1)
        assert blocked_input_mask.shape.as_list()[1] == global_seq_len

        # [batch_size, global_seq_len]
        # A global token is real when its block contains at least one token.
        block_has_token = tf.reduce_max(blocked_input_mask, axis=-1)
        global_input_mask = tf.minimum(block_has_token, 1)

        # [long_seq_len]
        # Block index of every long position, trimmed to the true length.
        block_ids = tf.range(global_seq_len, dtype=tf.int32)
        sentence_ids = tf.repeat(block_ids,
                                 num_tokens_per_block)[:long_seq_len]

        # [batch_size, long_seq_len]
        sentence_ids = tf.broadcast_to(sentence_ids,
                                       [batch_size, long_seq_len])

        side_inputs = (
            make_global_local_transformer_side_inputs_from_example_ids(
                long_example_ids=input_mask,
                global_example_ids=global_input_mask,
                sentence_ids=sentence_ids,
                local_radius=local_radius,
                relative_pos_max_distance=relative_pos_max_distance,
                use_hard_g2l_mask=use_hard_g2l_mask,
                use_hard_l2g_mask=use_hard_l2g_mask))

        global_token_ids = global_token_id * global_input_mask
        return side_inputs, global_token_ids
def compute_headwise_sparse_attention_mask(num_row_heads, num_column_heads,
                                           bucket_size, header_size,
                                           segment_ids, column_ids, row_ids,
                                           input_mask, **_):
    """Computes a 4D attention matrix that varies by head.

    Position i will attend to position j on head h when j has a 1 in the input
    mask and at least one of the following is true:
     * i and j are in the same row and h < num_row_heads
     * i and j are in the same column and h >= num_row_heads
     * j has segment id 0 (e.g. is part of the question)
     * i has segment id 0 (e.g. is part of the question)

    The mask returned by `_compute_cross_bucket_attention_mask` for the
    corresponding attribute (rows or columns) is applied on top. The intent is
    that, after sorting by that attribute and splitting the input into equal
    size buckets, only tokens in consecutive buckets or the first bucket
    attend to/from each other.

    Args:
      num_row_heads: <int32> Number of heads that attend within a row
      num_column_heads: <int32> Number of heads that attend within a column
      bucket_size: <int32> Only attend to position that fall in consecutive
        equally sized buckets, or to/from the first bucket.
      header_size: Optional<int32> The size of the first bucket. Will use
        `bucket_size` if None is passed.
      segment_ids: <int32>[batch_size, seq_length]
      column_ids: <int32>[batch_size, seq_length]
      row_ids: <int32>[batch_size, seq_length]
      input_mask: <int32>[batch_size, seq_length]
      **_: Extra keyword arguments are accepted and ignored.

    Returns:
      attention_mask: <float32>[batch_size, num_heads, seq_length, seq_length]
    """
    # <bool>[batch_size, seq_length]
    in_segment_zero = tf.math.equal(segment_ids, 0)
    # <bool>[batch_size, seq_length, seq_length]
    touches_segment_zero = (tf.expand_dims(in_segment_zero, axis=2)
                            | tf.expand_dims(in_segment_zero, axis=1))
    # <bool>[batch_size, 1, seq_length]
    valid_target = tf.expand_dims(tf.math.equal(input_mask, 1), axis=1)

    first_bucket_size = bucket_size if header_size is None else header_size

    # Rows first, columns second: the order fixes which heads get which mask.
    axis_ids = (row_ids, column_ids)
    bucket_masks = [
        _compute_cross_bucket_attention_mask(bucket_size, first_bucket_size,
                                             ids, input_mask)
        for ids in axis_ids
    ]

    # Two masks of <bool>[batch_size, seq_length, seq_length]
    per_axis_masks = [
        valid_target & bucket_mask
        & (_matches_token_type_id(ids) | touches_segment_zero)
        for ids, bucket_mask in zip(axis_ids, bucket_masks)
    ]

    # <bool>[batch_size, 2, seq_length, seq_length]
    rows_and_columns = tf.stack(per_axis_masks, axis=1)

    # Repeat the row-wise and column-wise attention the correct number of
    # times.
    # <bool>[batch_size, num_row_heads + num_column_heads, seq_length,
    #        seq_length]
    per_head = tf.repeat(rows_and_columns,
                         repeats=[num_row_heads, num_column_heads],
                         axis=1)
    return tf.cast(per_head, tf.float32)
image_classes, image_confidences = expand_labels(self._ancestors_lut, 1.0) new_image_classes, new_image_confidences = expand_labels( self._descendants_lut, 0.0) return new_image_classes, new_image_confidences def _expansion_box_field_labels(self, object_classes, object_field, copy_class_id=False): """Expand the labels of a specific object field according to the hierarchy. Args: object_classes: Int64 tensor with the class id for each element in object_field. object_field: Tensor to be expanded. copy_class_id: Boolean to choose whether to use class id values in the output tensor instead of replicating the original values. Returns: A tensor with the result of expanding object_field. """ expanded_indices = tf.gather( self._ancestors_lut, object_classes - _LABEL_OFFSET, axis=0) if copy_class_id: new_object_field = tf.where(expanded_indices > 0)[:, 1] + _LABEL_OFFSET else: new_object_field = tf.repeat( object_field, tf.reduce_sum(expanded_indices, axis=1), axis=0) return new_object_field
def mask2caffe(tanh):
    """Converts `tanh` into the input produced by `image2caffe_vgg`.

    The argument is first passed through `sigm2image`, the result is then
    repeated three times along axis 3, and finally handed to
    `image2caffe_vgg`.

    Args:
      tanh: Input forwarded to `sigm2image`.
        NOTE(review): presumably a single-channel NHWC mask, so that the
        repeat along axis 3 yields three identical channels -- confirm.

    Returns:
      Whatever `image2caffe_vgg` returns for the channel-repeated image.
    """
    as_image = sigm2image(tanh)
    # Replicate every entry along axis 3 three times.
    tripled = tf.repeat(as_image, repeats=3, axis=3)
    return image2caffe_vgg(tripled)
def _generate_detections_tf(cls_outputs,
                            box_outputs,
                            anchor_boxes,
                            indices,
                            classes,
                            image_id,
                            image_scale,
                            min_score_thresh=0.2,
                            max_boxes_to_draw=50,
                            soft_nms_sigma=0.0,
                            iou_threshold=0.5,
                            use_native_nms=True):
    """Builds the final detection rows from top-k model outputs and anchors.

    Scores are obtained with a sigmoid over `cls_outputs`, boxes by decoding
    `box_outputs` against the anchors selected by `indices`, and the
    candidates are then filtered by non-max suppression.

    Args:
      cls_outputs: tensor with the highest class logit of each of the N
        selected top-k anchors over all feature levels. It is used as a
        rank-1 `[N]` tensor: it is expanded to a column before being
        concatenated with the `[N, 4]` boxes.
      box_outputs: tensor of shape [N, 4] stacking the box regression
        outputs of the selected anchors.
      anchor_boxes: tensor of anchors over all feature levels; the rows at
        `indices` are gathered before decoding.
      indices: tensor of shape [N] with the indices from top-k selection.
      classes: tensor of shape [N] with the class prediction of every
        selected anchor.
      image_id: an integer number to specify the image id.
      image_scale: a float tensor with the scale between the original image
        and the detector input; box coordinates and sizes are multiplied by
        it.
      min_score_thresh: float score threshold below which boxes are removed.
        Only read when `use_native_nms` is true.
      max_boxes_to_draw: maximum number of boxes to keep. Only read when
        `use_native_nms` is true.
      soft_nms_sigma: scalar float Soft-NMS sigma (see Bodla et al,
        https://arxiv.org/abs/1704.04503); `0.0` means standard (hard) NMS.
        Only read when `use_native_nms` is true.
      iou_threshold: float IOU threshold above which boxes are considered
        to overlap too much.
      use_native_nms: if true use
        `tf.image.non_max_suppression_with_scores`, otherwise `nms_tf`.

    Returns:
      detections: float tensor with one row per kept box, laid out as
        [image_id, y, x, height, width, score, class]. The class column is
        the gathered entry of `classes` plus one.
    """
    logging.info('Using tf version of post-processing.')
    selected_anchors = tf.gather(anchor_boxes, indices)
    probabilities = tf.math.sigmoid(cls_outputs)

    # Apply the bounding box regression to the selected anchors. The decoder
    # is fed coordinate-major (transposed) tensors.
    decoded_boxes = decode_box_outputs_tf(
        tf.transpose(box_outputs, [1, 0]),
        tf.transpose(selected_anchors, [1, 0]))

    if not use_native_nms:
        logging.info('Using customized nms.')
        # Pack [box(4), score] rows so the custom NMS sees both at once.
        # Note that this path uses `iou_threshold` only.
        score_column = tf.expand_dims(probabilities, axis=1)
        candidates = tf.concat([decoded_boxes, score_column], axis=1)
        keep_idx = nms_tf(candidates, iou_threshold)
        survivors = tf.gather(candidates, keep_idx)
        kept_scores = survivors[:, 4]
        kept_boxes = survivors[:, :4]
    else:
        logging.info('Using native nms.')
        keep_idx, kept_scores = tf.image.non_max_suppression_with_scores(
            decoded_boxes,
            probabilities,
            max_boxes_to_draw,
            iou_threshold=iou_threshold,
            score_threshold=min_score_thresh,
            soft_nms_sigma=soft_nms_sigma)
        kept_boxes = tf.gather(decoded_boxes, keep_idx)

    # Box columns 0/1 are reported as y/x; 2/3 are the opposite corner and
    # only enter the output through the height/width differences.
    y_start = kept_boxes[:, 0]
    x_start = kept_boxes[:, 1]
    box_height = kept_boxes[:, 2] - y_start
    box_width = kept_boxes[:, 3] - x_start

    # One image id per kept detection.
    image_id_column = tf.cast(
        tf.repeat(image_id, tf.size(keep_idx)), tf.float32)
    # Shift the predicted classes by one.
    class_column = tf.cast(tf.gather(classes, keep_idx) + 1, tf.float32)

    columns = [
        image_id_column,
        y_start * image_scale,
        x_start * image_scale,
        box_height * image_scale,
        box_width * image_scale,
        kept_scores,
        class_column,
    ]
    return tf.stack(columns, axis=1)
def _get_indexes(seq_length, batch_size):
    """Returns the positions `0 .. seq_length - 1`, one copy per batch row.

    Args:
      seq_length: Number of positions to enumerate.
      batch_size: Number of times the position row is repeated.

    Returns:
      A tensor of shape [batch_size, seq_length] in which every row is
      `tf.range(seq_length)`.
    """
    positions = tf.range(seq_length)
    # Add a leading batch axis of size 1, then replicate along it.
    single_row = tf.expand_dims(positions, axis=0)
    return tf.repeat(single_row, repeats=batch_size, axis=0)