def hard_sw_affine(weights, tol=1e-6):
  """Solves the Smith-Waterman LP, computing both optimal scores and alignments.

  Args:
    weights: A tf.Tensor<float>[batch, len1, len2, 9] (len1 <= len2) of edge
      weights (see function alignment.weights_from_sim_mat for an in-depth
      description).
    tol: A small positive constant to ensure the first transition begins at the
      start state. Note(fllinares): this might not be needed anymore, test!

  Returns:
    Two tensors corresponding to the scores and alignments, respectively.
    + The first tf.Tensor<float>[batch] contains the Smith-Waterman scores for
      each pair of sequences in the batch.
    + The second tf.Tensor<int>[batch, len1, len2, 9] contains binary entries
      indicating the trajectory of the indices along the optimal path for each
      sequence pair, by having a one along the taken edges, with nine possible
      edges for each i,j.
  """
  # Gathers shape and type variables.
  b, l1, l2 = tf.shape(weights)[0], weights.shape[1], weights.shape[2]
  padded_len = l1 + l2 - 1
  dtype = weights.dtype
  inf = alignment.large_compatible_positive(dtype)

  # Rearranges input tensor for vectorized wavefront iterations.
  weights = wavefrontify(weights)  # [padded_len, s, l1, b]
  w_m, w_x, w_y = tf.split(weights, [4, 2, 3], axis=1)

  ### FORWARD

  # Auxiliary functions + syntactic sugar.
  def slice_lead_dims(t, k, s):
    """Returns t[k][:s] for "wavefrontified" tensors."""
    return tf.squeeze(tf.slice(t, [k, 0, 0, 0], [1, s, l1, b]), 0)

  # "Wavefrontified" tensors contain invalid entries that need to be masked.
  def slice_inv_mask(k):
    """Masks invalid and sentinel entries in wavefrontified tensors."""
    j_range = k - tf.range(1, l1 + 1, dtype=tf.int32) + 2
    return tf.logical_and(j_range > 0, j_range <= l2)  # True iff valid.

  # Sets up reduction operators.
  def reduce_max_with_argmax(t, axis=0):
    # Note(fllinares): I haven't yet managed to beat the performance of this
    # (wasteful) implementation with tf.argmax + tf.gather / tf.gather_nd :(
    t_max = tf.reduce_max(t, axis=axis)
    t_argmax = tf.argmax(t, axis=axis, output_type=tf.int32)
    return t_max, t_argmax

  # Initializes forward recursion.
  v_p2, v_p1 = tf.fill([3, l1, b], -inf), tf.fill([3, l1, b], -inf)
  # Ensures that edge cases for which all substitution costs are negative
  # result in a score of zero and an empty alignment.
  v_opt = tf.zeros(b, dtype=dtype)
  k_opt, i_opt = -tf.ones(b, dtype=tf.int32), -tf.ones(b, dtype=tf.int32)
  d_all = tf.TensorArray(tf.int32, size=padded_len, clear_after_read=True)

  # Runs forward Smith-Waterman recursion.
  for k in range(padded_len):
    # NOTE(fllinares): shape information along the batch dimension seems to get
    # lost in the edge-case b=1
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(v_p2, tf.TensorShape([3, None, None])),
                          (v_p1, tf.TensorShape([3, None, None])),
                          (v_opt, tf.TensorShape([None,])),
                          (k_opt, tf.TensorShape([None,])),
                          (i_opt, tf.TensorShape([None,]))])
    # inv_mask: masks out invalid entries for v_p2, v_p1 and v_opt updates.
    inv_mask_k = slice_inv_mask(k)[tf.newaxis, :, tf.newaxis]

    o_m = slice_lead_dims(w_m, k, 4) + alignment.top_pad(v_p2, tol)
    o_x = slice_lead_dims(w_x, k, 2) + v_p1[:2]
    v_p1 = alignment.left_pad(v_p1[:, :-1], -inf)
    o_y = slice_lead_dims(w_y, k, 3) + v_p1

    v_m, d_m = reduce_max_with_argmax(o_m, axis=0)
    v_x, d_x = reduce_max_with_argmax(o_x, axis=0)
    v_y, d_y = reduce_max_with_argmax(o_y, axis=0)
    v = tf.where(inv_mask_k, tf.stack([v_m, v_x, v_y]), -inf)
    d = tf.stack([d_m, d_x + 1, d_y + 1])  # Accounts for start state (0).
    v_p2, v_p1 = v_p1, v

    v_opt_k, i_opt_k = reduce_max_with_argmax(v[0], axis=0)
    update_cond = v_opt_k > v_opt
    v_opt = tf.where(update_cond, v_opt_k, v_opt)
    k_opt = tf.where(update_cond, k, k_opt)
    i_opt = tf.where(update_cond, i_opt_k, i_opt)
    d_all = d_all.write(k, d)

  ### BACKTRACKING

  # Creates auxiliary tensors to encode backtracking "actions".
  steps_k = tf.convert_to_tensor([0, -2, -1, -1], dtype=tf.int32)
  steps_i = tf.convert_to_tensor([0, -1, 0, -1], dtype=tf.int32)
  trans_enc = tf.constant([[10, 10, 10, 10],
                           [1, 2, 3, 4],
                           [10, 5, 6, 10],
                           [10, 7, 8, 9]], dtype=tf.int32)  # [m_curr, m_prev]
  samp_idx = tf.range(b, dtype=tf.int32)

  # Initializes additional backtracking variables.
  m_opt = tf.ones(b, dtype=tf.int32)  # Init at match states (by definition).
  paths_sp = tf.TensorArray(tf.int32, size=padded_len, clear_after_read=True)

  # Runs Smith-Waterman backtracking.
  for k in range(padded_len - 1, -1, -1):
    # NOTE(fllinares): shape information along the batch dimension seems to get
    # lost in the edge-case b=1
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(m_opt, tf.TensorShape([None,]))])
    # Computes tentative next indices for each alignment.
    k_opt_n = k_opt + tf.gather(steps_k, m_opt)
    i_opt_n = i_opt + tf.gather(steps_i, m_opt)
    # Computes tentative next state types for each alignment.
    m_opt_n_idx = tf.stack(
        [tf.maximum(m_opt - 1, 0), tf.maximum(i_opt, 0), samp_idx], -1)
    m_opt_n = tf.gather_nd(d_all.read(k), m_opt_n_idx)
    # Computes tentative next sparse updates for paths tensor.
    edges_n = tf.gather_nd(trans_enc, tf.stack([m_opt, m_opt_n], -1))
    paths_sp_n = tf.stack([samp_idx, i_opt + 1, k_opt - i_opt + 1, edges_n], -1)
    # Indicates alignments to be updated in this iteration.
    cond = tf.logical_and(k_opt == k, m_opt != 0)
    # Conditionally applies updates for each alignment.
    k_opt = tf.where(cond, k_opt_n, k_opt)
    i_opt = tf.where(cond, i_opt_n, i_opt)
    m_opt = tf.where(cond, m_opt_n, m_opt)
    paths_sp_k = tf.where(cond[:, None], paths_sp_n, tf.zeros([b, 4], tf.int32))
    paths_sp = paths_sp.write(k, paths_sp_k)  # [0, 0, 0, 0] used as dummy upd.

  # Applies sparse updates, building paths tensor.
  paths_sp = tf.reshape(paths_sp.stack(), [-1, 4])  # [(padded_len * b), 4]
  paths_sp_idx, paths_sp_upd = paths_sp[:, :3], paths_sp[:, 3]
  paths = tf.scatter_nd(paths_sp_idx, paths_sp_upd, [b, l1 + 1, l2 + 1])
  paths = paths[:, 1:, 1:]  # Removes sentinel row/col.
  # Represents paths tensor using one-hot encoding over 9 states.
  paths = tf.one_hot(paths, tf.reduce_max(trans_enc))[:, :, :, 1:]

  return v_opt, paths
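
# Illustrative usage sketch for `hard_sw_affine` (added for exposition; the
# helper name and tensor sizes below are arbitrary and not part of the original
# module). It assumes the 9-channel edge weights have already been built, e.g.
# via alignment.weights_from_sim_mat; here random values stand in for real
# weights.
def _hard_sw_affine_usage_example():
  weights = tf.random.normal([4, 16, 32, 9])  # [batch, len1, len2, 9].
  # Wrapping in tf.function traces the wavefront recursion into a graph, which
  # is the typical way these loops are meant to be run.
  scores, paths = tf.function(hard_sw_affine)(weights)
  # scores: tf.Tensor<float>[4], one optimal local-alignment score per pair.
  # paths: tf.Tensor[4, 16, 32, 9], one-hot encoding of the edges taken along
  #   each optimal path.
  return scores, paths
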
def forward(sim_mat, gap_open, gap_extend):
  # Gathers shape and type variables.
  b, l1, l2 = tf.shape(sim_mat)[0], tf.shape(sim_mat)[1], tf.shape(sim_mat)[2]
  padded_len = l1 + l2 - 1
  go_shape, ge_shape = gap_open.shape, gap_extend.shape
  dtype = sim_mat.dtype
  inf = alignment.large_compatible_positive(dtype)

  # Rearranges input tensor for vectorized wavefront iterations.
  def slice_lead_dim(t, k):
    """Returns t[k] for "wavefrontified" tensors."""
    return tf.squeeze(tf.slice(t, [k, 0, 0, 0], [1, 1, l1, b]), 0)

  # sim_mat ends with shape [l1+l2-1, 1, l1, b].
  sim_mat = wavefrontify(
      alignment.broadcast_to_shape(sim_mat, [b, l1, l2, 1]))

  def slice_sim_mat(k):
    return slice_lead_dim(sim_mat, k)  # [1, l1, b]

  # gap_open, gap_extend end with shape
  # - [l1+l2-1, 1, l1, b] if they are rank 3,
  # - [1, 1, b] if they are rank 1,
  # - [1, 1, 1] if they are rank 0.
  go_shape.assert_same_rank(ge_shape)  # TODO(fllinares): lift the constraint.
  if go_shape.rank == 0 or go_shape.rank == 1:
    gap_open = alignment.broadcast_to_rank(gap_open, rank=2, axis=0)
    gap_extend = alignment.broadcast_to_rank(gap_extend, rank=2, axis=0)
    gap_pen = tf.stack([gap_open, gap_open, gap_extend], axis=0)
    slice_gap_pen = lambda k: gap_pen
  else:
    gap_open = wavefrontify(
        alignment.broadcast_to_shape(gap_open, [b, l1, l2, 1]))
    gap_extend = wavefrontify(
        alignment.broadcast_to_shape(gap_extend, [b, l1, l2, 1]))

    def slice_gap_pen(k):
      gap_open_k = slice_lead_dim(gap_open, k)  # [1, l1, b]
      gap_extend_k = slice_lead_dim(gap_extend, k)  # [1, l1, b]
      return tf.concat([gap_open_k, gap_open_k, gap_extend_k], 0)  # [3,l1,b]

  # "Wavefrontified" tensors contain invalid entries that need to be masked.
  def slice_inv_mask(k):
    """Masks invalid and sentinel entries in wavefrontified tensors."""
    j_range = k - tf.range(1, l1 + 1, dtype=tf.int32) + 2
    return tf.logical_and(j_range > 0, j_range <= l2)  # True iff valid.

  # Sets up reduction operators.
  # TODO(fllinares): temp = 0 / None case.
  maxop = lambda t: temp * tf.reduce_logsumexp(t / temp, 0, True)
  argmaxop = lambda t: tf.nn.softmax(t / temp, 0)
  endop = lambda t: temp * tf.reduce_logsumexp(t / temp, [0, 1], True)

  # Initializes forward recursion.
  v_p2, v_p1 = tf.fill([3, l1, b], -inf), tf.fill([3, l1, b], -inf)
  v_all = tf.TensorArray(dtype, size=padded_len, clear_after_read=False)

  # Runs forward Smith-Waterman recursion.
  for k in tf.range(padded_len):
    # NOTE(fllinares): shape information along the batch dimension seems to
    # get lost in the edge-case b=1
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(v_p2, tf.TensorShape([3, None, None])),
                          (v_p1, tf.TensorShape([3, None, None]))])
    inv_mask_k = slice_inv_mask(k)[tf.newaxis, :, tf.newaxis]
    sim_mat_k, gap_pen_k = slice_sim_mat(k), slice_gap_pen(k)

    o_m = alignment.top_pad(v_p2, 0.0)  # [4, l1, b]
    o_x = v_p1[:2] - gap_pen_k[1:]  # [2, l1, b]
    v_p1 = alignment.left_pad(v_p1[:, :-1], -inf)  # [3, l1, b]
    o_y = v_p1 - gap_pen_k  # [3, l1, b]

    v = tf.concat([sim_mat_k + maxop(o_m), maxop(o_x), maxop(o_y)], 0)
    v = tf.where(inv_mask_k, v, -inf)  # [3, l1, b]
    v_p2, v_p1 = v_p1, v
    v_all = v_all.write(k, v)

  v_opt = endop(v_all.stack()[:, 0])

  def grad(dy):
    # NOTE(fllinares): we reuse value buffers closed over to store grads.
    nonlocal v_all

    def unsqueeze_lead_dim(t, i):
      """Returns tf.expand_dims(t[i], 0) for tf.Tensor `t`."""
      return tf.slice(t, [i, 0, 0], [1, l1, b])  # [1, l1, b]

    # Initializes backprop recursion.
    m_term_p2, m_term_p1 = tf.fill([3, l1, b], 0.0), tf.fill([3, l1, b], 0.0)
    x_term_p1, y_term_p1 = tf.fill([2, l1, b], 0.0), tf.fill([3, l1, b], 0.0)
    if go_shape.rank == 0:
      g_sm = tf.TensorArray(dtype, size=padded_len, clear_after_read=True)
      g_go, g_ge = tf.zeros([], dtype=dtype), tf.zeros([], dtype=dtype)
    elif go_shape.rank == 1:
      g_sm = tf.TensorArray(dtype, size=padded_len, clear_after_read=True)
      g_go, g_ge = tf.zeros([b], dtype=dtype), tf.zeros([b], dtype=dtype)
    else:
      # NOTE(fllinares): needed to pacify AutoGraph...
      g_sm, g_go, g_ge = 0.0, 0.0, 0.0

    # Runs backprop Smith-Waterman recursion.
    for k in tf.range(padded_len - 1, -1, -1):
      # NOTE(fllinares): shape information along the batch dimension seems to
      # get lost in the edge-case b=1
      tf.autograph.experimental.set_loop_options(
          shape_invariants=[(m_term_p2, tf.TensorShape([3, None, None])),
                            (m_term_p1, tf.TensorShape([3, None, None])),
                            (x_term_p1, tf.TensorShape([2, None, None])),
                            (y_term_p1, tf.TensorShape([3, None, None]))])
      # NOTE(fllinares): empirically, keeping v_m, v_x and v_y as separate
      # tf.TensorArrays appears slightly advantageous in TPU but rather
      # disadvantageous in GPU... Moreover, despite TPU performance being
      # improved for most inputs, certain input sizes (e.g. 1024 x 512 x 512)
      # lead to catastrophic (x100) runtime "spikes". Because of this, I have
      # decided to be conservative and keep v_m, v_x and v_y packed into a
      # single tensor until I understand better what's going on...
      v_k = v_all.read(k)
      v_n1 = v_all.read(k - 1) if k >= 1 else tf.fill([3, l1, b], -inf)
      v_n2 = v_all.read(k - 2) if k >= 2 else tf.fill([3, l1, b], -inf)

      gap_pen_k = slice_gap_pen(k)
      o_m = alignment.top_pad(v_n2, 0.0)  # [4, l1, b]
      o_x = v_n1[:2] - gap_pen_k[1:]  # [2, l1, b]
      o_y = alignment.left_pad(v_n1[:, :-1], -inf) - gap_pen_k  # [3, l1, b]

      m_tilde = argmaxop(o_m)[1:, :-1]  # [3, l1 - 1, b]
      x_tilde = argmaxop(o_x)  # [2, l1, b]
      y_tilde = argmaxop(o_y)  # [3, l1, b]

      # TODO(fllinares): might be able to improve numerical prec. in 1st term.
      m_adj = (tf.exp((unsqueeze_lead_dim(v_k, 0) - v_opt) / temp) +
               unsqueeze_lead_dim(m_term_p2, 0) +
               unsqueeze_lead_dim(y_term_p1, 0) +
               unsqueeze_lead_dim(x_term_p1, 0))  # [1, l1, b]
      x_adj = (unsqueeze_lead_dim(m_term_p2, 1) +
               unsqueeze_lead_dim(y_term_p1, 1) +
               unsqueeze_lead_dim(x_term_p1, 1))  # [1, l1, b]
      y_adj = (unsqueeze_lead_dim(m_term_p2, 2) +
               unsqueeze_lead_dim(y_term_p1, 2))  # [1, l1, b]

      m_term = m_adj[:, 1:] * m_tilde  # [3, l1 - 1, b]
      x_term = x_adj * x_tilde  # [2, l1, b]
      y_term = y_adj * y_tilde  # [3, l1, b]

      g_sm_k = m_adj
      g_go_k = -(unsqueeze_lead_dim(x_term, 0) +
                 unsqueeze_lead_dim(y_term, 0) +
                 unsqueeze_lead_dim(y_term, 1))  # [1, l1, b]
      g_ge_k = -(unsqueeze_lead_dim(x_term, 1) +
                 unsqueeze_lead_dim(y_term, 2))  # [1, l1, b]

      # NOTE(fllinares): empirically, avoiding unnecessary tf.TensorArray
      # writes for g_go and g_ge when gap penalties have rank 0 or 1 is
      # again advantageous in TPU, but does not seem to yield consistently
      # better performance in GPU.
      if go_shape.rank == 0:  # pytype: disable=attribute-error
        g_sm = g_sm.write(k, g_sm_k)
        g_go += tf.reduce_sum(g_go_k)
        g_ge += tf.reduce_sum(g_ge_k)
      elif go_shape.rank == 1:
        g_sm = g_sm.write(k, g_sm_k)
        g_go += tf.reduce_sum(g_go_k, [0, 1])
        g_ge += tf.reduce_sum(g_ge_k, [0, 1])
      else:
        v_all = v_all.write(k, tf.concat([g_sm_k, g_go_k, g_ge_k], 0))

      m_term_p2, m_term_p1 = m_term_p1, alignment.right_pad(m_term, 0.0)
      # NOTE(fllinares): empirically, the roll-based solution appears to
      # improve over right_pad(y_term[:, 1:], 0.0) in TPU while being
      # somewhat slower in GPU...
      x_term_p1, y_term_p1 = x_term, tf.roll(y_term, -1, axis=1)

    if go_shape.rank == 0 or go_shape.rank == 1:
      g_sm = tf.squeeze(unwavefrontify(g_sm.stack()), axis=-1)
    else:
      g = unwavefrontify(v_all.stack())
      g_sm, g_go, g_ge = g[Ellipsis, 0], g[Ellipsis, 1], g[Ellipsis, 2]

    g_sm *= dy[:, tf.newaxis, tf.newaxis]
    if go_shape.rank == 0:
      dy_gap_pen = tf.reduce_sum(dy)
    elif go_shape.rank == 1:
      dy_gap_pen = dy
    else:
      dy_gap_pen = dy[:, tf.newaxis, tf.newaxis]
    g_go *= dy_gap_pen
    g_ge *= dy_gap_pen

    return g_sm, g_go, g_ge

  return v_opt[0, 0], grad
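
# Hedged sketch (assumption, illustrative names only): `forward` returns the
# (value, grad_fn) pair expected by tf.custom_gradient and closes over a
# smoothing temperature `temp`, so it is presumably nested inside a decorated
# enclosing function (such as `soft_sw_affine`, referenced below) roughly along
# these lines:
#
#   @tf.custom_gradient
#   def _soft_sw_affine_value(sim_mat, gap_open, gap_extend):
#     temp = 1.0  # Closed over by `forward` / `grad`; the real code may expose
#                 # it as an argument.
#     # ... `forward` as defined above, nested here ...
#     return forward(sim_mat, gap_open, gap_extend)
#
# tf.custom_gradient then calls `grad` with the upstream gradient `dy` of shape
# [batch] and expects (g_sim_mat, g_gap_open, g_gap_extend) back, matching the
# ranks of the corresponding inputs.
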
def soft_sw_affine_fwd(sim_mat, gap_open, gap_extend, temp=1.0):
  """Solves the smoothed Smith-Waterman LP, computing the softmax values only.

  This function currently provides the fastest and most memory-efficient
  Smith-Waterman forward recursion in this module, but relies on autodiff for
  backtracking / backprop. See `smith_waterman` and `soft_sw_affine` for
  implementations with custom backtracking / backprop.

  Args:
    sim_mat: a tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with the
      substitution values for pairs of sequences.
    gap_open: a tf.Tensor<float>[], tf.Tensor<float>[batch] or
      tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with the penalties
      for opening a gap. Must agree in rank with gap_extend.
    gap_extend: a tf.Tensor<float>[], tf.Tensor<float>[batch] or
      tf.Tensor<float>[batch, len1, len2] (len1 <= len2) with the penalties
      for extending a gap. Must agree in rank with gap_open.
    temp: a float controlling the regularization strength. If None, the
      unsmoothed DP will be solved instead (i.e. equivalent to temperature
      = 0).

  Returns:
    A tf.Tensor<float>[batch] of softmax values computed in the forward pass.
  """
  # Gathers shape and type variables.
  b, l1, l2 = tf.shape(sim_mat)[0], tf.shape(sim_mat)[1], tf.shape(sim_mat)[2]
  padded_len = l1 + l2 - 1
  go_shape, ge_shape = gap_open.shape, gap_extend.shape
  dtype = sim_mat.dtype
  inf = alignment.large_compatible_positive(dtype)

  # Rearranges input tensor for vectorized wavefront iterations.
  def slice_lead_dim(t, k):
    """Returns t[k] for "wavefrontified" tensors."""
    return tf.squeeze(tf.slice(t, [k, 0, 0, 0], [1, 1, l1, b]), 0)

  # sim_mat ends with shape [l1+l2-1, 1, l1, b].
  sim_mat = wavefrontify(alignment.broadcast_to_shape(sim_mat, [b, l1, l2, 1]))

  def slice_sim_mat(k):
    return slice_lead_dim(sim_mat, k)  # [1, l1, b]

  # gap_open, gap_extend end with shape
  # - [l1+l2-1, 1, l1, b] if they are rank 3,
  # - [1, 1, b] if they are rank 1,
  # - [1, 1, 1] if they are rank 0.
  go_shape.assert_same_rank(ge_shape)  # TODO(fllinares): lift this constraint.
  if go_shape.rank == 0 or go_shape.rank == 1:
    gap_open = alignment.broadcast_to_rank(gap_open, rank=2, axis=0)
    gap_extend = alignment.broadcast_to_rank(gap_extend, rank=2, axis=0)
    gap_pen = tf.stack([gap_open, gap_open, gap_extend], axis=0)
    slice_gap_pen = lambda k: gap_pen
  else:
    gap_open = wavefrontify(
        alignment.broadcast_to_shape(gap_open, [b, l1, l2, 1]))
    gap_extend = wavefrontify(
        alignment.broadcast_to_shape(gap_extend, [b, l1, l2, 1]))

    def slice_gap_pen(k):
      gap_open_k = slice_lead_dim(gap_open, k)  # [1, l1, b]
      gap_extend_k = slice_lead_dim(gap_extend, k)  # [1, l1, b]
      return tf.concat([gap_open_k, gap_open_k, gap_extend_k], 0)  # [3, l1, b]

  # "Wavefrontified" tensors contain invalid entries that need to be masked.
  def slice_inv_mask(k):
    """Masks invalid and sentinel entries in wavefrontified tensors."""
    j_range = k - tf.range(1, l1 + 1, dtype=tf.int32) + 2
    return tf.logical_and(j_range > 0, j_range <= l2)  # True iff valid.

  # Sets up reduction operators.
  if temp is None:
    maxop = lambda t: tf.reduce_max(t, 0, keepdims=True)
    endop = lambda t: tf.reduce_max(t, [0, 1])
  else:
    maxop = lambda t: temp * tf.reduce_logsumexp(t / temp, 0, keepdims=True)
    endop = lambda t: temp * tf.reduce_logsumexp(t / temp, [0, 1])

  # Initializes recursion.
  v_p2, v_p1 = tf.fill([3, l1, b], -inf), tf.fill([3, l1, b], -inf)
  v_m_all = tf.TensorArray(dtype, size=padded_len, clear_after_read=False)

  # Runs forward Smith-Waterman recursion.
  for k in tf.range(padded_len):
    # NOTE(fllinares): shape information along the batch dimension seems to get
    # lost in the edge-case b=1
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(v_p2, tf.TensorShape([3, None, None])),
                          (v_p1, tf.TensorShape([3, None, None]))])
    inv_mask_k = slice_inv_mask(k)[tf.newaxis, :, tf.newaxis]
    sim_mat_k, gap_pen_k = slice_sim_mat(k), slice_gap_pen(k)

    o_m = alignment.top_pad(v_p2, 0.0)  # [4, l1, b]
    o_x = v_p1[:2] - gap_pen_k[1:]  # [2, l1, b]
    v_p1 = alignment.left_pad(v_p1[:, :-1], -inf)
    o_y = v_p1 - gap_pen_k  # [3, l1, b]

    v = tf.concat([sim_mat_k + maxop(o_m), maxop(o_x), maxop(o_y)], 0)
    v = tf.where(inv_mask_k, v, -inf)  # [3, l1, b]
    v_p2, v_p1 = v_p1, v
    v_m_all = v_m_all.write(k, v[0])

  return endop(v_m_all.stack())
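
# Illustrative usage sketch for `soft_sw_affine_fwd` (added for exposition; the
# helper name, tensor sizes and penalty values are arbitrary and not part of
# the original module). As the docstring above notes, backprop goes through
# plain autodiff, so gradients can be obtained with tf.GradientTape.
def _soft_sw_affine_fwd_usage_example():
  sim_mat = tf.random.normal([4, 16, 32])  # [batch, len1, len2].
  gap_open = tf.constant(11.0)  # Rank-0 penalties; [batch] or
  gap_extend = tf.constant(1.0)  # [batch, len1, len2] would also work.
  with tf.GradientTape() as tape:
    tape.watch([sim_mat, gap_open, gap_extend])
    values = tf.function(soft_sw_affine_fwd)(
        sim_mat, gap_open, gap_extend, temp=1.0)  # tf.Tensor<float>[4].
  grads = tape.gradient(values, [sim_mat, gap_open, gap_extend])
  return values, grads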