Example No. 1
def flipflop_bwd(scores):
    """ Backward calculation for flipflop transitions from raw network output

    The backward matrix entries are:

        bwd[t, s] = logsumexp(scores of paths starting in state s at time t)

    Paths can end in any state at time T.

    For numerical reasons, we calculate a row-normalised backward matrix
    and a vector of scaling factors:

        bwd'[t, s] = bwd[t, s] - logsumexp(bwd[t, s])
        fact[t] = logsumexp(bwd[t, s]) - logsumexp(bwd[t + 1, s])

    :param scores: a [T, B, S] tensor containing a batch of B scores matrices
        each with T blocks and S flipflop transition scores, where
        S = 2 * nbase * (nbase + 1)

    :returns: (bwd, fact) tensors of shape [T + 1, B, 2 * nbase] and [T + 1, B, 1]
    """
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    bwd = torch.zeros((T + 1, N, 2 * nbase),
                      dtype=scores.dtype,
                      device=scores.device)
    fact = torch.zeros((T + 1, N, 1), dtype=scores.dtype, device=scores.device)
    with cp.cuda.Device(index):
        _flipflop_bwd(grid=(N, 1, 1),
                      block=(2 * nbase, 1, 1),
                      args=(scores.contiguous().data_ptr(), bwd.data_ptr(),
                            fact.data_ptr(), T, N, nbase))
    return bwd, fact
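A shape-only usage sketch follows (the tensor sizes and the CUDA device are hypothetical, chosen to match the docstring; this is not library code):

# Hedged usage sketch: shapes only. For a 4-base model,
# S = 2 * 4 * (4 + 1) = 40 transition scores per block.
import torch

scores = torch.randn(100, 8, 40, device='cuda')  # [T, B, S]
bwd, fact = flipflop_bwd(scores)
# bwd:  [101, 8, 8]  row-normalised backward matrix (2 * nbase = 8 states)
# fact: [101, 8, 1]  per-block scaling factors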
Example No. 2
def errprobs_from_trans(trans, path):
    """Calculate error probs from (batch of) posterior trans weights and path

    This is done by:

      sum(trans posteriors for all transitions into base at path[b] at block b)
    p=-------------------------------------------------------------------------
        sum(trans posteriors for all transitions into any base at block b)

    errorprob = 1-p

    Args:
        trans (:torch:`Tensor`): Tensor of floats with shape
            (nblocks x batchsize x nstates), where nstates = 40 for 4-base
            models, containing posterior transition weights (not logs!)

        path (:torch:`Tensor`): Tensor of longs with shape
            ((nblocks+1) x batchsize) containing flip-flop states (integers 0-7
            for 4-base models). The transition that goes with trans[n,bn,:] is
            the one from path[n,bn] to path[n+1,bn].

    Returns:
        :torch:`Tensor` : errorprob = tensor of floats with shape
            ((nblocks+1) x batchsize) containing errorprob for each element of
            the path, and -1.0 in row 0. Note that this doesn't matter since
            these probabilities are removed later on in the pipeline. The
            output matrix must be the same shape as the path in order to be fed
            into the stitching function.
    """
    nblocks, batchsize, flip_flop_transitions = trans.shape
    nbases = flipflopfings.nbase_flipflop(flip_flop_transitions)
    # baseprobs will contain total probability for emission of each base
    # at each block normalised by prob of emitting any base.
    baseprobs = torch.zeros((nblocks, batchsize, nbases),
                            dtype=torch.float,
                            device=trans.device)
    # Calculate total probability of transition into base b at block n
    for destbase in range(nbases):
        t = transitions_into_base(destbase, nbases, device=trans.device)
        m = torch.zeros(flip_flop_transitions,
                        dtype=torch.float,
                        device=trans.device)
        m[t] = 1.0
        baseprobs[:, :, destbase] = torch.matmul(trans, m)

    # Normalise
    baseprobs = baseprobs / (baseprobs.sum(dim=2, keepdim=True) + SMALL_VAL)

    # Calculate matrix p (see docstring)
    p = torch.empty_like(path, dtype=torch.float)
    # baseprobs is nblocks x batchsize x nbases, path is (nblocks+1) x
    # batchsize
    ix = path[1:].unsqueeze(2) % nbases
    p[1:] = torch.gather(baseprobs, 2, ix).squeeze(2)
    # set p = 2 at block 0 so that errorprob = 1 - p = -1 there (see docstring)
    p[0] = 2.0
    return 1.0 - p
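A toy illustration of the docstring's formula (illustrative numbers only, not library code):

import torch

# Posterior mass flowing into each of the 4 bases at one block of one read.
base_mass = torch.tensor([0.1, 0.6, 0.2, 0.1])
called_base = 1                               # path[n] % nbases for that block
p = base_mass[called_base] / base_mass.sum()  # 0.6
errorprob = 1.0 - p                           # 0.4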
Example No. 3
def test_fwd_equals_global_norm(self):
    nbase = int(nbase_flipflop(self.weights.shape[1]))
    nstate = nbase + nbase
    init = np.zeros(nstate, dtype='f4')
    init[nbase:] = -50000
    fwd, f2 = decodeutil.forward(self.weights, init=init)
    fwd_score = float(logsumexp(fwd[-1], axis=0))
    print(f2, fwd_score, self.tensor_score)
    print(decodeutil.forward(self.weights, init=None)[1])
    self.assertAlmostEqual(fwd_score, self.tensor_score, places=5)
Example No. 4
def network(insize=1, size=512, winlen=19, stride=2, outsize=40):
    nbase = nbase_flipflop(outsize)

    return Serial([
        Convolution(insize, size, winlen, stride=stride, fun=tanh),
        Reverse(GruMod(size, size)),
        GruMod(size, size),
        Reverse(GruMod(size, size)),
        GruMod(size, size),
        Reverse(GruMod(size, size)),
        GlobalNormFlipFlop(size, nbase),
    ])
Example No. 5
def network(insize=1, size=256, winlen=19, stride=2, outsize=40):
    nbase = nbase_flipflop(outsize)
    # outsize = 2 * nbase * (nbase + 1): each of the nbase flip states can be
    # reached from any of the 2 * nbase states, while each flop state is
    # reachable only from its own flip or flop state, so 2*4*4 + 2*4 = 40.

    return Serial([
        Convolution(insize, size, winlen, stride=stride, fun=tanh),
        Reverse(GruMod(size, size)),
        GruMod(size, size),
        Reverse(GruMod(size, size)),
        GruMod(size, size),
        Reverse(GruMod(size, size)),
        GlobalNormFlipFlop(size, nbase),
    ])
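A minimal sketch of what nbase_flipflop presumably computes (an assumption, not the library source): invert outsize = 2 * nbase * (nbase + 1) for nbase.

import math

def nbase_from_ntrans(ntrans):
    # Solve 2 * n * (n + 1) = ntrans for n.
    n = (math.sqrt(1 + 2 * ntrans) - 1) / 2
    assert n == int(n), 'not a valid flip-flop transition count'
    return int(n)

print(nbase_from_ntrans(40))  # 4, matching outsize=40 above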
Example No. 6
def flipflop_bwd(scores):
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    bwd = torch.zeros(
        (T + 1, N, 2 * nbase), dtype=scores.dtype, device=scores.device)
    fact = torch.zeros((T + 1, N, 1), dtype=scores.dtype, device=scores.device)
    with cp.cuda.Device(index):
        _flipflop_bwd(grid=(N, 1, 1), block=(2 * nbase, 1, 1),
                      args=(scores.contiguous().data_ptr(), bwd.data_ptr(),
                            fact.data_ptr(), T, N, nbase))
    return bwd, fact
Example No. 7
def log_partition_flipflop(scores):
    T, N, C = scores.shape
    nbase = flipflopfings.nbase_flipflop(C)

    fwd = torch.cat([torch.zeros(N, nbase, device=scores.device,
                                 dtype=scores.dtype),
                     torch.full((N, nbase), -LARGE_LOG_VAL, device=scores.device,
                                dtype=scores.dtype)],
                    1)
    logZ = fwd.logsumexp(1, keepdim=True)
    fwd = fwd - logZ
    nbase = torch.tensor(nbase, device=scores.device, dtype=torch.long)
    for scores_t in scores.unbind(0):
        factors, fwd = global_norm_flipflop_step(scores_t, fwd, nbase)
        logZ = logZ + factors
    return logZ
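A shape-only usage sketch (hypothetical sizes):

import torch

T, B, S = 100, 8, 40                   # 4-base model: S = 2 * 4 * (4 + 1)
scores = torch.randn(T, B, S)
logZ = log_partition_flipflop(scores)  # [B, 1]: log partition per read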
Example No. 8
def _flipflop_viterbi(scores):
    """ Find highest scoring flipflop paths for a batch of score matrices. This
        is an idiomatic pytorch implementation.

    Args:
        scores (:torch:`Tensor`): batch of score matrices with dimensions
            [T, batch size, S] where T is the number of blocks (time axis) and
            S is the number of distinct flipflop transitions. For 4 bases
            S = 40, and in general S = 2 * nbase * (nbase + 1).

    Returns:
        tuple(:torch:`Tensor`, :torch:`Tensor`, :torch:`Tensor`):
            fwd scores tensor, traceback tensor, flipflop path tensor
    """
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    fwd = torch.zeros(T + 1,
                      N,
                      2 * nbase,
                      device=scores.device,
                      dtype=scores.dtype)
    fwd[0, :, nbase:] = -LARGE_VAL
    traceback = torch.zeros(T,
                            N,
                            2 * nbase,
                            device=scores.device,
                            dtype=torch.long)

    for t in range(T):
        to_flip = scores[t, :, :S - 2 * nbase].reshape((N, nbase, 2 * nbase))
        fwd[t + 1, :, :nbase], traceback[t, :, :nbase] = \
            (fwd[t].unsqueeze(1) + to_flip).max(2)
        to_flop = (fwd[t] + scores[t, :, -2 * nbase:]).reshape((N, 2, nbase))
        fwd[t + 1, :, nbase:], tb_flop = to_flop.max(1)
        traceback[t, :, nbase:] = nbase * tb_flop + torch.arange(
            nbase, device=traceback.device, dtype=traceback.dtype)

    path = torch.zeros(T + 1, N, device=scores.device, dtype=torch.long)

    path[T] = fwd[T].argmax(1)
    ix = torch.arange(N, device=traceback.device, dtype=torch.long)
    for t in range(T - 1, -1, -1):
        path[t] = traceback[t, ix, path[t + 1]]

    return fwd, traceback, path
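The returned path can be collapsed into a base call. A hedged sketch of the usual flip-flop decoding rule (not this library's stitching code): a base is emitted for the initial state and at every change of state, and flip/flop variants of the same base map to one letter via state % nbase.

def collapse_path(read_path, nbase, alphabet='ACGT'):
    # read_path: the [T + 1] states for a single read (one column of `path`).
    states = [int(s) for s in read_path]
    seq = [alphabet[states[0] % nbase]]
    for prev, curr in zip(states, states[1:]):
        if curr != prev:                       # a stay emits nothing
            seq.append(alphabet[curr % nbase])
    return ''.join(seq)

# e.g. nbase=4: [0, 0, 4, 0, 1] -> 'AAAC'
# (stay at flip-A, then flop-A, flip-A and flip-C each emit a base)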
Example No. 9
def flipflop_viterbi(scores):
    """ Calculate the Viterbi path through flipflop scores matrix

    The scores are assumed to be in log space, i.e. the probability
    of a path is proportional to exp(sum of transition scores on path)

    The function returns the (Viterbi) forward matrix defined as:

        viterbi_fwd[t, s] = score of best path to time t and state s

    and a traceback matrix:

        traceback[t, s] = previous state on best path to time t and state s

    and a vector encoding the sequence of states on the best path.

    Args:
        scores: a [T, B, S] tensor containing a batch of B scores matrices
            each with T blocks and S flipflop transition scores, where
            S = 2 * nbase * (nbase + 1)

    Returns:
        (fwd, traceback, best_path) tensors of shapes [T + 1, B, 2 * nbase],
            [T + 1, B, 2 * nbase] and [T + 1, B]
    """
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    scores = scores.contiguous()
    fwd = torch.zeros((T + 1, N, 2 * nbase),
                      dtype=scores.dtype,
                      device=scores.device)
    fwd[:1, :, nbase:] = -LARGE_VAL
    traceback = torch.zeros((T + 1, N, 2 * nbase),
                            dtype=torch.long,
                            device=scores.device)
    best_path = torch.zeros((T + 1, N), dtype=torch.long, device=scores.device)
    with cp.cuda.Device(index):
        _flipflop_viterbi(grid=(N, 1, 1),
                          block=(nbase, 1, 1),
                          args=(scores.data_ptr(), fwd.data_ptr(),
                                traceback.data_ptr(), best_path.data_ptr(), T,
                                N, nbase))
    return fwd, traceback, best_path
Example No. 10
def flipflop_make_trans(scores):
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    fwd, fwd_fact = flipflop_fwd(scores)
    bwd, bwd_fact = flipflop_bwd(scores)
    scores = scores.contiguous()
    trans = torch.zeros_like(scores)
    kernel_args = (
        scores.contiguous().data_ptr(),
        fwd.data_ptr(),
        bwd.data_ptr(),
        trans.data_ptr(),
        T, N, nbase,
    )
    with cp.cuda.Device(index):
        _flipflop_make_trans(grid=(N,), block=(2 * nbase,), args=kernel_args)
    return trans, fwd_fact, bwd_fact
Example No. 11
def flipflop_fwd(scores):
    """ Forward calculation for flipflop transitions from raw network output

    The forward matrix entries are:

        fwd[t, s] = logsumexp(scores of paths ending in state s at time t)

    Paths must start in a flip state at time 0.

    For numerical reasons, we calculate a row-normalised forward matrix
    and a vector of scaling factors:

        fwd'[t, s] = fwd[t, s] - logsumexp(fwd[t, s])
        fact[t] = logsumexp(fwd[t, s]) - logsumexp(fwd[t - 1, s])

    Args:
        scores: a [T, B, S] tensor containing a batch of B scores matrices
            each with T blocks and S flipflop transition scores, where
            S = 2 * nbase * (nbase + 1)

    Returns:
        (fwd, fact) tensors of shape [T + 1, B, 2 * nbase] and [T + 1, B, 1]
    """
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    fwd = torch.zeros((T + 1, N, 2 * nbase),
                      dtype=scores.dtype,
                      device=scores.device)
    fwd[0, :, :nbase] = 0.0
    fwd[0, :, nbase:] = -LARGE_VAL
    fact = torch.zeros((T + 1, N, 1), dtype=scores.dtype, device=scores.device)
    with cp.cuda.Device(index):
        _flipflop_fwd(grid=(N, 1, 1),
                      block=(nbase, 1, 1),
                      args=(scores.contiguous().data_ptr(), fwd.data_ptr(),
                            fact.data_ptr(), T, N, nbase))

    return fwd, fact
Example No. 12
def global_norm_flipflop(scores):
    T, N, C = scores.shape
    nbase = flipflopfings.nbase_flipflop(C)

    def step(scores_t, fwd_t):
        curr_scores = fwd_t.unsqueeze(1) + scores_t.reshape(
            (-1, nbase + 1, 2 * nbase))
        base1_state = curr_scores[:, :nbase].logsumexp(2)
        base2_state = logaddexp(curr_scores[:, nbase, :nbase],
                                curr_scores[:, nbase, nbase:])
        new_state = torch.cat([base1_state, base2_state], dim=1)
        factors = new_state.logsumexp(1, keepdim=True)
        new_state = new_state - factors
        return factors, new_state

    fwd = scores.new_zeros((N, 2 * nbase))
    logZ = fwd.logsumexp(1, keepdim=True)
    fwd = fwd - logZ
    for scores_t in scores:
        factors, fwd = step(scores_t, fwd)
        logZ = logZ + factors
    return scores - logZ / T
Example No. 13
def flipflop_viterbi(scores):
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    scores = scores.contiguous()
    fwd = torch.zeros((T + 1, N, 2 * nbase), dtype=scores.dtype, device=scores.device)
    traceback = torch.zeros((T + 1, N, 2 * nbase), dtype=torch.long, device=scores.device)
    best_path = torch.zeros((T + 1, N), dtype=torch.long, device=scores.device)
    with cp.cuda.Device(index):
        _flipflop_viterbi(
            grid=(N, 1, 1),
            block=(nbase, 1, 1),
            args=(
                scores.data_ptr(),
                fwd.data_ptr(),
                traceback.data_ptr(),
                best_path.data_ptr(),
                T, N, nbase
            )
        )
    return fwd, traceback, best_path
Example No. 14
def flipflop_make_trans(scores):
    """ Calculates posterior transition probabilities from flipflop scores

    The posterior transition probabilities matrix is defined as:

        trans[t, uv] = probability of paths that use transition uv at time t

    Paths must start in a flip state at time 0.

    Args:
        scores: a [T, B, S] tensor containing a batch of B scores matrices
            each with T blocks and S flipflop transition scores, where
            S = 2 * nbase * (nbase + 1)

    Returns:
        (trans, fwd_fact, bwd_fact) tensors of shape [T, B, S],
            [T + 1, B, 1] and [T + 1, B, 1]. fwd_fact and bwd_fact are the
            scaling factors returned by flipflop_fwd and flipflop_bwd
    """
    index = scores.device.index
    T, N, S = scores.shape
    nbase = flipflopfings.nbase_flipflop(S)

    fwd, fwd_fact = flipflop_fwd(scores)
    bwd, bwd_fact = flipflop_bwd(scores)
    scores = scores.contiguous()
    trans = torch.zeros_like(scores)
    kernel_args = (
        scores.contiguous().data_ptr(),
        fwd.data_ptr(),
        bwd.data_ptr(),
        trans.data_ptr(),
        T,
        N,
        nbase,
    )
    with cp.cuda.Device(index):
        _flipflop_make_trans(grid=(N, ), block=(2 * nbase, ), args=kernel_args)
    return trans, fwd_fact, bwd_fact
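A hedged sketch of how this fits together with the flipflop_viterbi and errprobs_from_trans examples above, assuming trans is in probability (not log) space as errprobs_from_trans expects, and that scores sit on a CUDA device:

trans, fwd_fact, bwd_fact = flipflop_make_trans(scores)  # trans: [T, B, S]
_, _, path = flipflop_viterbi(scores)                    # path:  [T + 1, B]
errprobs = errprobs_from_trans(trans, path)              # [T + 1, B] per-block error prob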
Example No. 15
def parse_sublayer(sublayer):
    # TODO apply additional attributes (e.g. has_bias, convolutional padding)
    if sublayer['type'] == 'convolution':
        if sublayer['activation'] != 'tanh':
            sys.stderr.write((
                'Incompatible convolutional layer activation function ' +
                '({}) encountered.\n').format(sublayer['activation']))
            sys.exit(1)
        sys.stderr.write((
            'Loading convolutional layer with attributes:\n\tin size: {}\n' +
            '\tout size: {}\n\twinlen: {}\n\tstride: {}\n').format(
                sublayer['insize'], sublayer['size'], sublayer['winlen'],
                sublayer['stride']))
        layer = Convolution(
            sublayer['insize'], sublayer['size'], sublayer['winlen'],
            stride=sublayer['stride'], fun=tanh)
    elif sublayer['type'] == 'LSTM':
        sys.stderr.write((
            'Loading LSTM layer with attributes:\n\tin size: {}\n' +
            '\tout size: {}\n').format(
                sublayer['insize'], sublayer['size']))
        layer = Lstm(sublayer['insize'], sublayer['size'])
    elif sublayer['type'] == 'GruMod':
        sys.stderr.write((
            'Loading GRU layer with attributes:\n\tin size: {}\n' +
            '\tout size: {}\n').format(
                sublayer['insize'], sublayer['size']))
        layer = GruMod(sublayer['insize'], sublayer['size'])
    elif sublayer['type'] == 'reverse':
        sublayer = sublayer['sublayers']
        if sublayer['type'] == 'GruMod':
            sys.stderr.write((
                'Loading Reverse GRU layer with attributes:\n\tin size: {}\n' +
                '\tout size: {}\n').format(
                    sublayer['insize'], sublayer['size']))
            layer = Reverse(GruMod(sublayer['insize'], sublayer['size']))
        elif sublayer['type'] == 'LSTM':
            sys.stderr.write((
                'Loading Reverse LSTM layer with attributes:\n' +
                '\tin size: {}\n\tout size: {}\n').format(
                    sublayer['insize'], sublayer['size']))
            layer = Reverse(Lstm(sublayer['insize'], sublayer['size']))
        else:
            sys.stderr.write((
                'Invalid reversed-time layer type ({})\n').format(
                    sublayer['type']))
            sys.exit(1)
    elif sublayer['type'] == 'GlobalNormTwoState':
        nbase = nbase_flipflop(sublayer['size'])
        sys.stderr.write((
            'Loading flip-flop layer with attributes:\n\tin size: {}\n' +
            '\tnbases: {}\n').format(sublayer['insize'], nbase))
        layer = GlobalNormFlipFlop(sublayer['insize'], nbase)
    elif sublayer['type'] == 'GlobalNormTwoStateCatMod':
        output_alphabet = sublayer['output_alphabet']
        curr_can_base = 0
        collapse_alphabet = ''
        for can_i_nmod in sublayer['can_nmods']:
            collapse_alphabet += output_alphabet[curr_can_base] * (
                can_i_nmod + 1)
            curr_can_base += can_i_nmod + 1
        alphabet_info = alphabet.AlphabetInfo(
            output_alphabet, collapse_alphabet,
            sublayer['modified_base_long_names'], do_reorder=False)
        sys.stderr.write((
            'Loading modified bases flip-flop layer with attributes:\n' +
            '\tin size: {}\n\tmod bases: {}\n').format(
                sublayer['insize'], alphabet_info.mod_long_names))
        layer = GlobalNormFlipFlopCatMod(sublayer['insize'], alphabet_info)
    else:
        sys.stderr.write('Encountered invalid layer type ({}).\n'.format(
            sublayer['type']))
        sys.exit(1)

    layer = set_params(layer, sublayer['params'], sublayer['type'])

    return layer
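For reference, a hedged example of the kind of sublayer description this parser accepts, inferred only from the keys accessed above (the real serialised format may differ):

conv_sublayer = {
    'type': 'convolution',
    'activation': 'tanh',
    'insize': 1,
    'size': 256,
    'winlen': 19,
    'stride': 2,
    'params': {},  # placeholder; the real weights are consumed by set_params
}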
Example No. 16
def test_bwd_equals_global_norm(self):
    nbase = int(nbase_flipflop(self.weights.shape[1]))
    bwd, _ = decodeutil.backward(self.weights)
    bwd_score = float(logsumexp(bwd[0, :nbase], axis=0))
    self.assertAlmostEqual(bwd_score, self.tensor_score, places=5)