def flipflop_bwd(scores):
    """Backward recursion over flip-flop transition scores (CUDA kernel).

    The backward matrix is

        bwd[t, s] = logsumexp(scores of paths starting in state s at time t)

    where paths may end in any state at time T. To keep the numbers in
    range, each row is normalised and the per-step scaling factors are
    returned separately:

        bwd'[t, s] = bwd[t, s] - logsumexp(bwd[t, s])
        fact[t]    = logsumexp(bwd[t, s]) - logsumexp(bwd[t + 1, s])

    Args:
        scores: a [T, N, S] tensor holding N score matrices, each with T
            blocks of S = 2 * nbase * (nbase + 1) flip-flop transition
            scores. The kernel is launched through cupy on
            scores.device.index, so scores is expected to be a CUDA tensor.

    Returns:
        (bwd, fact) tensors of shape [T + 1, N, 2 * nbase] and [T + 1, N, 1].
    """
    n_blocks, n_batch, n_trans = scores.shape
    nbase = flipflopfings.nbase_flipflop(n_trans)
    nstate = 2 * nbase
    alloc = dict(dtype=scores.dtype, device=scores.device)
    bwd = torch.zeros((n_blocks + 1, n_batch, nstate), **alloc)
    fact = torch.zeros((n_blocks + 1, n_batch, 1), **alloc)
    with cp.cuda.Device(scores.device.index):
        # One CUDA block per batch element, one thread per flip-flop state
        _flipflop_bwd(
            grid=(n_batch, 1, 1),
            block=(nstate, 1, 1),
            args=(scores.contiguous().data_ptr(), bwd.data_ptr(),
                  fact.data_ptr(), n_blocks, n_batch, nbase),
        )
    return bwd, fact
def errprobs_from_trans(trans, path):
    """Calculate error probs from (batch of) posterior trans weights and path

    This is done by:

             sum(trans posteriors for all transitions into base at path[b] at block b)
        p = ---------------------------------------------------------------------------
             sum(trans posteriors for all transitions into any base at block b)

        errorprob = 1 - p

    Args:
        trans (:torch:`Tensor`): Tensor of floats with shape
            (nblocks x batchsize x nstates) where nstates = 40 for 4-base
            models, containing posterior transition weights (not logs!).
            Any floating dtype is accepted; the sums are computed in at least
            single precision (float32, or float64 for float64 input).
        path (:torch:`Tensor`): Tensor of longs with shape
            ((nblocks+1) x batchsize) containing flip-flop states (integers
            0-7 for 4-base models). The transition that goes with
            trans[n,bn,:] is the one from path[n,bn] to path[n+1,bn].

    Returns:
        :torch:`Tensor` : errorprob = float32 tensor with shape
        ((nblocks+1) x batchsize) containing errorprob for each element of the
        path, and -1.0 in row 0. Note that this doesn't matter since these
        probabilities are removed later on in the pipeline. The output matrix
        must be the same shape as the path in order to be fed into the
        stitching function.
    """
    nblocks, batchsize, flip_flop_transitions = trans.shape
    nbases = flipflopfings.nbase_flipflop(flip_flop_transitions)

    # A fixed float32 mask cannot be matmul'd with float64/float16 trans, so
    # compute in the promoted dtype (a no-op .to() for float32 input).
    compute_dtype = torch.promote_types(trans.dtype, torch.float)

    # into_base[s, b] == 1 when transition s ends in a state of base b, so a
    # single matmul gives the total probability of transitions into each base.
    into_base = torch.zeros((flip_flop_transitions, nbases),
                            dtype=compute_dtype, device=trans.device)
    for destbase in range(nbases):
        t = transitions_into_base(destbase, nbases, device=trans.device)
        into_base[t, destbase] = 1.0
    # baseprobs is nblocks x batchsize x nbases
    baseprobs = torch.matmul(trans.to(compute_dtype), into_base)

    # Normalise by the probability of emitting any base at each block
    baseprobs = baseprobs / (baseprobs.sum(dim=2, keepdim=True) + SMALL_VAL)

    # path is (nblocks+1) x batchsize; state % nbases gives the base of a
    # flip or flop state. Block n's transition leads into path[n + 1].
    errprob = torch.empty_like(path, dtype=torch.float)
    ix = path[1:].unsqueeze(2) % nbases
    errprob[1:] = 1.0 - torch.gather(baseprobs, 2, ix).squeeze(2)
    # Row 0 has no incoming transition; flag it with -1.0 (see docstring)
    errprob[0] = -1.0
    return errprob
def test_fwd_equals_global_norm(self):
    """Forward score over paths starting in flip states matches tensor_score."""
    nbase = int(nbase_flipflop(self.weights.shape[1]))
    nstate = nbase + nbase
    # Start log-weights: 0 for the flip states, -50000 for the flop states.
    # Presumably init is a per-state start log-score, so this effectively
    # forbids paths from starting in a flop state.
    init = np.zeros(nstate, dtype='f4')
    init[nbase:] = -50000
    fwd, _ = decodeutil.forward(self.weights, init=init)
    fwd_score = float(logsumexp(fwd[-1], axis=0))
    self.assertAlmostEqual(fwd_score, self.tensor_score, places=5)
def network(insize=1, size=512, winlen=19, stride=2, outsize=40):
    """Build a flip-flop basecalling network (512 units per layer by default).

    Layout: a strided tanh convolution, five GruMod layers alternating in
    direction (reversed, forward, reversed, forward, reversed), then a
    GlobalNormFlipFlop output layer. outsize is the number of flip-flop
    transition scores, from which nbase_flipflop derives the number of bases.
    """
    nbase = nbase_flipflop(outsize)
    layers = [Convolution(insize, size, winlen, stride=stride, fun=tanh)]
    for depth in range(5):
        gru = GruMod(size, size)
        # Even depths run backwards in time, odd depths forwards
        layers.append(Reverse(gru) if depth % 2 == 0 else gru)
    layers.append(GlobalNormFlipFlop(size, nbase))
    return Serial(layers)
def network(insize=1, size=256, winlen=19, stride=2, outsize=40):
    """Build a flip-flop basecalling network (256 units per layer by default).

    Layout: a strided tanh convolution, five GruMod layers alternating in
    direction (reversed, forward, reversed, forward, reversed), then a
    GlobalNormFlipFlop output layer.

    outsize is the number of flip-flop transition scores per block, which is
    2 * nbase * (nbase + 1).
      * Each of the nbase flip states can be entered from any of the
        2 * nbase states, giving 2 * nbase**2 scores.
      * Each of the nbase flop states can only be entered from its own flip
        state or from itself, giving 2 * nbase scores.
    For nbase = 4 this is 32 + 8 = 40. nbase_flipflop maps outsize back to
    nbase for the output layer.
    """
    nbase = nbase_flipflop(outsize)
    layers = [Convolution(insize, size, winlen, stride=stride, fun=tanh)]
    for depth in range(5):
        gru = GruMod(size, size)
        # Even depths run backwards in time, odd depths forwards
        layers.append(Reverse(gru) if depth % 2 == 0 else gru)
    layers.append(GlobalNormFlipFlop(size, nbase))
    return Serial(layers)
def flipflop_bwd(scores):
    """Run the cupy backward kernel over a batch of flip-flop score matrices.

    Args:
        scores: [T, N, S] tensor of flip-flop transition scores, with
            S = 2 * nbase * (nbase + 1). The kernel runs on
            scores.device.index via cupy, so this is expected to be a CUDA
            tensor.

    Returns:
        (bwd, fact): a row-normalised backward matrix of shape
        [T + 1, N, 2 * nbase] and per-step scaling factors of shape
        [T + 1, N, 1].
    """
    device = scores.device
    n_blocks, n_batch, n_trans = scores.shape
    nbase = flipflopfings.nbase_flipflop(n_trans)
    bwd = torch.zeros((n_blocks + 1, n_batch, 2 * nbase),
                      dtype=scores.dtype, device=device)
    fact = torch.zeros((n_blocks + 1, n_batch, 1),
                       dtype=scores.dtype, device=device)
    with cp.cuda.Device(device.index):
        kernel_args = (scores.contiguous().data_ptr(), bwd.data_ptr(),
                       fact.data_ptr(), n_blocks, n_batch, nbase)
        _flipflop_bwd(grid=(n_batch, 1, 1), block=(2 * nbase, 1, 1),
                      args=kernel_args)
    return bwd, fact
def log_partition_flipflop(scores):
    """Log partition function of a batch of flip-flop score matrices.

    The recursion starts from log-weight 0 for every flip state and
    -LARGE_LOG_VAL for every flop state, normalised so the state vector
    sums to one. Each block of scores is then folded in with
    global_norm_flipflop_step, and the scaling factors it returns are
    accumulated.

    Args:
        scores: [T, N, C] tensor of flip-flop transition scores.

    Returns:
        Accumulated log partition per batch element. It starts as an (N, 1)
        tensor (keepdim logsumexp), and presumably keeps that shape if the
        step factors do — TODO confirm.
    """
    T, N, C = scores.shape
    nbase = flipflopfings.nbase_flipflop(C)
    like_scores = dict(device=scores.device, dtype=scores.dtype)
    start = torch.cat(
        [torch.zeros(N, nbase, **like_scores),
         torch.full((N, nbase), -LARGE_LOG_VAL, **like_scores)],
        1)
    log_z = start.logsumexp(1, keepdim=True)
    state = start - log_z
    # The step function receives nbase as a long tensor, not a Python int
    nbase_tensor = torch.tensor(nbase, device=scores.device, dtype=torch.long)
    for block_scores in scores.unbind(0):
        block_factors, state = global_norm_flipflop_step(
            block_scores, state, nbase_tensor)
        log_z = log_z + block_factors
    return log_z
def _flipflop_viterbi(scores):
    """ Find highest scoring flipflop paths for a batch of score matrices.

    This is an idiomatic pytorch implementation.

    Args:
        scores (:torch:`Tensor`): batch of score matrices with dimensions
            [T, batch size, S] where T is the number of blocks (time axis)
            and S is the number of distinct flipflop transitions. For 4 bases
            S = 40, and in general S = 2 * nbase * (nbase + 1).

    Returns:
        tuple(:torch:`Tensor`, :torch:`Tensor`, :torch:`Tensor`):
            fwd scores tensor [T + 1, N, 2 * nbase], traceback tensor
            [T, N, 2 * nbase] (traceback[t] gives the state at time t that
            leads to each state at time t + 1) and flipflop path tensor
            [T + 1, N].
    """
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)
    nstate = 2 * nbase

    fwd = torch.zeros(T + 1, N, nstate, device=scores.device,
                      dtype=scores.dtype)
    # Flop states start at -LARGE_VAL, which effectively restricts paths to
    # begin in a flip state
    fwd[0, :, nbase:] = -LARGE_VAL
    traceback = torch.zeros(T, N, nstate, device=scores.device,
                            dtype=torch.long)
    # Flop state k is entered from flip k (choice 0) or flop k (choice 1), so
    # its source is nbase * choice + k. The k part is the same for every
    # block, so build it once instead of on each time step.
    flop_index = torch.arange(nbase, device=traceback.device,
                              dtype=traceback.dtype)

    for t in range(T):
        # First S - 2 * nbase scores: entering each flip state from any state
        to_flip = scores[t, :, :S - nstate].reshape((N, nbase, nstate))
        flip_best, flip_src = (fwd[t].unsqueeze(1) + to_flip).max(2)
        fwd[t + 1, :, :nbase] = flip_best
        traceback[t, :, :nbase] = flip_src

        # Last 2 * nbase scores: entering flop k from flip k, then from flop k
        to_flop = (fwd[t] + scores[t, :, -nstate:]).reshape((N, 2, nbase))
        flop_best, flop_choice = to_flop.max(1)
        fwd[t + 1, :, nbase:] = flop_best
        traceback[t, :, nbase:] = nbase * flop_choice + flop_index

    # Trace the best path back from the best final state
    path = torch.zeros(T + 1, N, device=scores.device, dtype=torch.long)
    path[T] = fwd[T].argmax(1)
    batch_ix = torch.arange(N, device=traceback.device, dtype=torch.long)
    for t in range(T - 1, -1, -1):
        path[t] = traceback[t, batch_ix, path[t + 1]]

    return fwd, traceback, path
def flipflop_viterbi(scores):
    """Viterbi decoding of flip-flop scores with a CUDA kernel.

    Scores are in log space, so the probability of a path is proportional to
    exp(sum of the transition scores along it). Three results are produced:

        viterbi_fwd[t, s] = score of best path to time t and state s
        traceback[t, s]   = previous state on best path to time t and state s
        best_path         = sequence of states on the overall best path

    The flop states of fwd are set to -LARGE_VAL at t = 0 before the kernel
    runs, which effectively makes the best path start in a flip state.

    Args:
        scores: a [T, N, S] tensor holding N score matrices, each with T
            blocks of S = 2 * nbase * (nbase + 1) flip-flop transition scores.

    Returns:
        (fwd, traceback, best_path) tensors of shapes [T + 1, N, 2 * nbase],
        [T + 1, N, 2 * nbase] and [T + 1, N].
    """
    device = scores.device
    n_blocks, n_batch, n_trans = scores.shape
    nbase = flipflopfings.nbase_flipflop(n_trans)
    # Keep the contiguous tensor bound so its pointer is what the kernel reads
    scores = scores.contiguous()
    state_shape = (n_blocks + 1, n_batch, 2 * nbase)
    fwd = torch.zeros(state_shape, dtype=scores.dtype, device=device)
    fwd[0, :, nbase:] = -LARGE_VAL
    traceback = torch.zeros(state_shape, dtype=torch.long, device=device)
    best_path = torch.zeros((n_blocks + 1, n_batch), dtype=torch.long,
                            device=device)
    launch_args = (scores.data_ptr(), fwd.data_ptr(), traceback.data_ptr(),
                   best_path.data_ptr(), n_blocks, n_batch, nbase)
    with cp.cuda.Device(device.index):
        _flipflop_viterbi(grid=(n_batch, 1, 1), block=(nbase, 1, 1),
                          args=launch_args)
    return fwd, traceback, best_path
def flipflop_make_trans(scores):
    """Posterior flip-flop transition probabilities via cupy kernels.

    Runs flipflop_fwd and flipflop_bwd, then combines their matrices with
    the scores in the _flipflop_make_trans kernel.

    Args:
        scores: [T, N, S] tensor of flip-flop transition scores with
            S = 2 * nbase * (nbase + 1), expected on a CUDA device.

    Returns:
        (trans, fwd_fact, bwd_fact): trans has the shape of scores
        ([T, N, S]). fwd_fact and bwd_fact are the [T + 1, N, 1] scaling
        factors from flipflop_fwd and flipflop_bwd.
    """
    device_index = scores.device.index
    n_blocks, n_batch, n_trans = scores.shape
    nbase = flipflopfings.nbase_flipflop(n_trans)
    fwd, fwd_fact = flipflop_fwd(scores)
    bwd, bwd_fact = flipflop_bwd(scores)
    # Make scores contiguous before zeros_like so trans is laid out densely too
    dense_scores = scores.contiguous()
    trans = torch.zeros_like(dense_scores)
    with cp.cuda.Device(device_index):
        _flipflop_make_trans(
            grid=(n_batch,),
            block=(2 * nbase,),
            args=(dense_scores.data_ptr(), fwd.data_ptr(), bwd.data_ptr(),
                  trans.data_ptr(), n_blocks, n_batch, nbase),
        )
    return trans, fwd_fact, bwd_fact
def flipflop_fwd(scores):
    """Forward recursion over flip-flop transition scores (CUDA kernel).

    The forward matrix is

        fwd[t, s] = logsumexp(scores of paths ending in state s at time t)

    with paths required to start in a flip state at time 0. At t = 0 the
    flip states hold 0 and the flop states hold -LARGE_VAL. To keep the
    numbers in range, each row is normalised and the per-step scaling
    factors are returned separately:

        fwd'[t, s] = fwd[t, s] - logsumexp(fwd[t, s])
        fact[t]    = logsumexp(fwd[t, s]) - logsumexp(fwd[t - 1, s])

    Args:
        scores: a [T, N, S] tensor holding N score matrices, each with T
            blocks of S = 2 * nbase * (nbase + 1) flip-flop transition scores.

    Returns:
        (fwd, fact) tensors of shape [T + 1, N, 2 * nbase] and [T + 1, N, 1].
    """
    n_blocks, n_batch, n_trans = scores.shape
    nbase = flipflopfings.nbase_flipflop(n_trans)
    alloc = dict(dtype=scores.dtype, device=scores.device)
    # torch.zeros already gives the flip states their start log-weight of 0
    fwd = torch.zeros((n_blocks + 1, n_batch, 2 * nbase), **alloc)
    fwd[0, :, nbase:] = -LARGE_VAL
    fact = torch.zeros((n_blocks + 1, n_batch, 1), **alloc)
    with cp.cuda.Device(scores.device.index):
        _flipflop_fwd(
            grid=(n_batch, 1, 1),
            block=(nbase, 1, 1),
            args=(scores.contiguous().data_ptr(), fwd.data_ptr(),
                  fact.data_ptr(), n_blocks, n_batch, nbase),
        )
    return fwd, fact
def global_norm_flipflop(scores):
    """Globally normalise a batch of flip-flop transition score matrices.

    A forward recursion over the T blocks accumulates the log partition
    function logZ for each batch element. It starts from a uniform
    distribution over all 2 * nbase states (an all-zero state vector,
    normalised). logZ / T is then subtracted from every score.

    Args:
        scores: [T, N, C] tensor with C = 2 * nbase * (nbase + 1) flip-flop
            transition scores per block.

    Returns:
        Tensor of the same shape as scores.
    """
    T, N, C = scores.shape
    nbase = flipflopfings.nbase_flipflop(C)

    state = scores.new_zeros((N, 2 * nbase))
    log_z = state.logsumexp(1, keepdim=True)
    state = state - log_z
    for block_scores in scores:
        # Row k < nbase: entering flip state k from each of the 2 * nbase
        # states. Row nbase: entering flop state k from flip k (columns
        # :nbase) or from flop k itself (columns nbase:).
        moves = state.unsqueeze(1) + block_scores.reshape(
            (-1, nbase + 1, 2 * nbase))
        flip_part = moves[:, :nbase].logsumexp(2)
        flop_part = logaddexp(moves[:, nbase, :nbase],
                              moves[:, nbase, nbase:])
        state = torch.cat([flip_part, flop_part], dim=1)
        block_norm = state.logsumexp(1, keepdim=True)
        state = state - block_norm
        log_z = log_z + block_norm
    return scores - log_z / T
def flipflop_viterbi(scores):
    """Viterbi decoding of flip-flop scores with a CUDA kernel.

    Scores are in log space: the probability of a path is proportional to
    exp(sum of transition scores on the path). The flop states of the
    forward matrix are set to -LARGE_VAL at t = 0, so the best path starts
    in a flip state. This matches flipflop_fwd and the pure-pytorch Viterbi.

    Args:
        scores: a [T, N, S] tensor holding N score matrices, each with T
            blocks of S = 2 * nbase * (nbase + 1) flip-flop transition scores.

    Returns:
        (fwd, traceback, best_path) tensors of shapes [T + 1, N, 2 * nbase],
        [T + 1, N, 2 * nbase] and [T + 1, N].
    """
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)
    # Keep the contiguous tensor bound so its pointer is what the kernel reads
    scores = scores.contiguous()
    fwd = torch.zeros((T + 1, N, 2 * nbase), dtype=scores.dtype,
                      device=scores.device)
    # Without this, paths could start in a flop state, unlike the forward
    # calculation used in training
    fwd[0, :, nbase:] = -LARGE_VAL
    traceback = torch.zeros((T + 1, N, 2 * nbase), dtype=torch.long,
                            device=scores.device)
    best_path = torch.zeros((T + 1, N), dtype=torch.long,
                            device=scores.device)
    with cp.cuda.Device(index):
        _flipflop_viterbi(
            grid=(N, 1, 1),
            block=(nbase, 1, 1),
            args=(
                scores.data_ptr(), fwd.data_ptr(), traceback.data_ptr(),
                best_path.data_ptr(), T, N, nbase
            )
        )
    return fwd, traceback, best_path
def flipflop_make_trans(scores):
    """ Calculates posterior transition probabilities from flipflop scores

    The posterior transition probabilities matrix is defined as:

        trans[t, uv] = probability of paths that use transition uv at time t

    Paths must start in a flip state at time 0.

    Args:
        scores: a [T, N, S] tensor containing a batch of N scores matrices
            each with T blocks and S flipflop transitions scores, where
            S = 2 * nbase * (nbase + 1)

    Returns:
        (trans, fwd_fact, bwd_fact) tensors of shape [T, N, S],
        [T + 1, N, 1] and [T + 1, N, 1]. fwd_fact and bwd_fact are the
        scaling factors returned by flipflop_fwd and flipflop_bwd.
    """
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    fwd, fwd_fact = flipflop_fwd(scores)
    bwd, bwd_fact = flipflop_bwd(scores)

    # scores is made contiguous once here; zeros_like then gives a
    # contiguous trans as well, matching the layout the kernel writes
    scores = scores.contiguous()
    trans = torch.zeros_like(scores)
    kernel_args = (
        scores.data_ptr(),
        fwd.data_ptr(),
        bwd.data_ptr(),
        trans.data_ptr(),
        T, N, nbase,
    )
    with cp.cuda.Device(index):
        _flipflop_make_trans(grid=(N, ), block=(2 * nbase, ),
                             args=kernel_args)

    return trans, fwd_fact, bwd_fact
def parse_sublayer(sublayer):
    """Build a layer from one sublayer description and load its parameters.

    Supported sublayer['type'] values are:
      * 'convolution' (tanh activation only)
      * 'LSTM'
      * 'GruMod'
      * 'reverse' (wrapping a 'GruMod' or 'LSTM' description stored under
        sublayer['sublayers'])
      * 'GlobalNormTwoState'
      * 'GlobalNormTwoStateCatMod'
    A summary of each layer is written to stderr.

    Args:
        sublayer (dict): layer description. The keys read depend on the type
            (e.g. 'insize', 'size', 'winlen', 'stride', 'activation',
            'sublayers', 'output_alphabet', 'can_nmods',
            'modified_base_long_names'), plus 'params', which is passed to
            set_params.

    Returns:
        The layer returned by set_params.

    Raises:
        SystemExit: via sys.exit(1), after writing a message to stderr, when
            the layer type, reversed layer type or convolution activation is
            not supported.
    """
    # TODO apply additional attributes (e.g. has_bias, convolutional padding)
    if sublayer['type'] == 'convolution':
        if sublayer['activation'] != 'tanh':
            # Report the offending activation, not the layer type
            sys.stderr.write((
                'Incompatible convolutional layer activation function ' +
                '({}) encountered.\n').format(sublayer['activation']))
            sys.exit(1)
        sys.stderr.write((
            'Loading convolutional layer with attributes:\n\tin size: {}\n' +
            '\tout size: {}\n\twinlen: {}\n\tstride: {}\n').format(
                sublayer['insize'], sublayer['size'], sublayer['winlen'],
                sublayer['stride']))
        layer = Convolution(
            sublayer['insize'], sublayer['size'], sublayer['winlen'],
            stride=sublayer['stride'], fun=tanh)
    elif sublayer['type'] == 'LSTM':
        sys.stderr.write((
            'Loading LSTM layer with attributes:\n\tin size: {}\n' +
            '\tout size: {}\n').format(
                sublayer['insize'], sublayer['size']))
        layer = Lstm(sublayer['insize'], sublayer['size'])
    elif sublayer['type'] == 'GruMod':
        sys.stderr.write((
            'Loading GRU layer with attributes:\n\tin size: {}\n' +
            '\tout size: {}\n').format(
                sublayer['insize'], sublayer['size']))
        layer = GruMod(sublayer['insize'], sublayer['size'])
    elif sublayer['type'] == 'reverse':
        # Rebind to the wrapped description: its 'params' and 'type' are the
        # ones passed to set_params below.
        sublayer = sublayer['sublayers']
        if sublayer['type'] == 'GruMod':
            sys.stderr.write((
                'Loading Reverse GRU layer with attributes:\n\tin size: {}\n' +
                '\tout size: {}\n').format(
                    sublayer['insize'], sublayer['size']))
            layer = Reverse(GruMod(sublayer['insize'], sublayer['size']))
        elif sublayer['type'] == 'LSTM':
            sys.stderr.write((
                'Loading Reverse LSTM layer with attributes:\n' +
                '\tin size: {}\n\tout size: {}\n').format(
                    sublayer['insize'], sublayer['size']))
            layer = Reverse(Lstm(sublayer['insize'], sublayer['size']))
        else:
            sys.stderr.write((
                'Invalid reversed-time layer type ({})\n').format(
                    sublayer['type']))
            sys.exit(1)
    elif sublayer['type'] == 'GlobalNormTwoState':
        nbase = nbase_flipflop(sublayer['size'])
        sys.stderr.write((
            'Loading flip-flop layer with attributes:\n\tin size: {}\n' +
            '\tnbases: {}\n').format(sublayer['insize'], nbase))
        layer = GlobalNormFlipFlop(sublayer['insize'], nbase)
    elif sublayer['type'] == 'GlobalNormTwoStateCatMod':
        output_alphabet = sublayer['output_alphabet']
        # Each canonical base is repeated once for itself plus once per
        # modification of it, so collapse_alphabet lines up with
        # output_alphabet position by position.
        curr_can_base = 0
        collapse_alphabet = ''
        for can_i_nmod in sublayer['can_nmods']:
            collapse_alphabet += output_alphabet[curr_can_base] * (
                can_i_nmod + 1)
            curr_can_base += can_i_nmod + 1
        alphabet_info = alphabet.AlphabetInfo(
            output_alphabet, collapse_alphabet,
            sublayer['modified_base_long_names'], do_reorder=False)
        sys.stderr.write((
            'Loading modified bases flip-flop layer with attributes:\n' +
            '\tin size: {}\n\tmod bases: {}\n').format(
                sublayer['insize'], alphabet_info.mod_long_names))
        layer = GlobalNormFlipFlopCatMod(sublayer['insize'], alphabet_info)
    else:
        sys.stderr.write('Encountered invalid layer type ({}).\n'.format(
            sublayer['type']))
        sys.exit(1)

    layer = set_params(layer, sublayer['params'], sublayer['type'])

    return layer
def test_bwd_equals_global_norm(self):
    """Backward score summed over the flip states at t = 0 matches tensor_score."""
    n_flip = int(nbase_flipflop(self.weights.shape[1]))
    bwd_matrix, _ = decodeutil.backward(self.weights)
    # Only the first n_flip columns (the flip states) of row 0 are summed
    start_scores = bwd_matrix[0, :n_flip]
    score = float(logsumexp(start_scores, axis=0))
    self.assertAlmostEqual(score, self.tensor_score, places=5)