Example 1
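Both examples are GPflow kernel subclasses and rely on names (Kernel, Parameter, tf, tfb) that are not imported in the snippets themselves. A minimal sketch of the assumed setup (note that tfb.AffineScalar, used in Example 1, only exists in older TensorFlow Probability releases; Example 2 uses the newer tfb.Shift / tfb.Scale):

import tensorflow as tf
import tensorflow_probability as tfp
from gpflow import Parameter
from gpflow.kernels import Kernel

# shorthand for the bijectors module, as used throughout both kernels
tfb = tfp.bijectors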
class StringKernel(Kernel):
    """
    Code to run the SSK of Moss et al. 2020 with gpflow
    
   with hyperparameters:
    1) match_decay float
        decrease the contribution of long subsequences
    2) gap_decay float
        decrease the contribtuion of subsequences with large gaps (penalize non-contiguous)
    3) max_subsequence_length int 
        largest subsequence considered
    4) max_occurence_length int
        longest non-contiguous occurences of subsequences considered (max_occurence_length > max_subsequence_length)
    We calculate gradients for match_decay and gap_decay w.r.t kernel hyperparameters following Beck (2017)
    We recommend normalize = True to allow meaningful comparrison of strings of different length
    """
    def __init__(self,
                 active_dims=[0],
                 gap_decay=0.1,
                 match_decay=0.9,
                 max_subsequence_length=3,
                 max_occurence_length=10,
                 alphabet=[],
                 maxlen=0,
                 normalize=True,
                 batch_size=1000):
        super().__init__(active_dims=active_dims)
        # constrain kernel params to between 0 and 1
        self.logistic_gap = tfb.Chain([
            tfb.AffineScalar(shift=tf.cast(0, tf.float64),
                             scale=tf.cast(1, tf.float64)),
            tfb.Sigmoid()
        ])
        self.logistic_match = tfb.Chain([
            tfb.AffineScalar(shift=tf.cast(0, tf.float64),
                             scale=tf.cast(1, tf.float64)),
            tfb.Sigmoid()
        ])
        self.gap_decay_param = Parameter(gap_decay,
                                         transform=self.logistic_gap,
                                         name="gap_decay")
        self.match_decay_param = Parameter(match_decay,
                                           transform=self.logistic_match,
                                           name="match_decay")
        self.max_subsequence_length = max_subsequence_length
        self.max_occurence_length = max_occurence_length
        self.alphabet = alphabet
        self.maxlen = maxlen
        self.normalize = normalize
        self.batch_size = batch_size
        self.symmetric = False

        # we use copies of the kernel params to avoid building an expensive computation graph;
        # we instead efficiently calculate gradients using dynamic programming
        # these copies are refreshed at every call to K and K_diag (in case the parameters have been updated)
        self.match_decay = self.match_decay_param.numpy()
        self.gap_decay = self.gap_decay_param.numpy()
        self.match_decay_unconstrained = self.match_decay_param.unconstrained_variable.numpy(
        )
        self.gap_decay_unconstrained = self.gap_decay_param.unconstrained_variable.numpy(
        )

        # initialize helpful construction matrices, to be lazily computed once needed
        self.D = None
        self.dD_dgap = None

        # build a lookup table of the alphabet to encode input strings
        self.table = tf.lookup.StaticHashTable(
            initializer=tf.lookup.KeyValueTensorInitializer(
                keys=tf.constant(["PAD"] + alphabet),
                values=tf.constant(range(0,
                                         len(alphabet) + 1)),
            ),
            default_value=0)

    def K_diag(self, X):
        r"""
        Calc just the diagonal elements of a kernel matrix
        """

        # check that no input string is longer than maxlen
        # (inputs are space-separated characters, so a string of n characters has length 2n - 1)
        if tf.reduce_max(tf.strings.length(X)) + 1 > 2 * self.maxlen:
            raise ValueError(
                "An input string is longer than maxlen, so refit the kernel with a larger maxlen param"
            )

        if self.normalize:
            # if normalizing then diagonal will just be ones
            return tf.cast(tf.fill(tf.shape(X)[:-1], 1), tf.float64)
        else:
            # otherwise have to calc kernel elements
            # encode strings as integer sequences via the lookup table and pad until all have the same length
            X = tf.strings.split(tf.squeeze(X, 1)).to_tensor(
                "PAD", shape=[None, self.maxlen])
            X = self.table.lookup(X)

            # prep required quantities and check kernel parameters
            self._precalc()

            # Proceed with kernel matrix calculations in batches
            k_results = tf.TensorArray(tf.float64,
                                       size=0,
                                       dynamic_size=True,
                                       infer_shape=False)

            num_batches = tf.cast(
                tf.math.ceil(tf.shape(X)[0] / self.batch_size), dtype=tf.int32)
            # iterate through batches
            for i in tf.range(num_batches):
                X_batch = X[self.batch_size * i:self.batch_size * (i + 1)]
                k_results = k_results.write(k_results.size(),
                                            self._k(X_batch, X_batch))

            # collect all batches
            return tf.reshape(k_results.concat(), (-1, ))

    def K(self, X, X2=None):
        r"""
        Now we calculate the kernel values and kernel gradients
        Efficientely calculating kernel gradients requires dynamic programming 
        and so we 'turn off' autograd and calculate manually

        We currently only bother calculating the kernel gradients for gram matricies
        i.e (when X=X2) as required when fitting the model.
        For predictions (where X != X2) we do not calculate gradients
        """

        if X2 is None:
            self.symmetric = True
            k_results = self.K_calc(X, X)
        else:
            self.symmetric = False
            k_results = self.K_calc(X, X2)

        return k_results

    def _precalc(self):
        r"""
        Update stored kernel params (in case they have changed)
        and precalc D and dD_dgap as required for kernel calcs
        following notation from Beck (2017)
        """
        self.match_decay = self.match_decay_param.numpy()
        self.gap_decay = self.gap_decay_param.numpy()
        self.match_decay_unconstrained = self.match_decay_param.unconstrained_variable.numpy(
        )
        self.gap_decay_unconstrained = self.gap_decay_param.unconstrained_variable.numpy(
        )

        tril = tf.linalg.band_part(
            tf.ones((self.maxlen, self.maxlen), dtype=tf.float64), -1, 0)
        # get upper triangle matrix of increasing integers
        values = tf.TensorArray(tf.int32, size=self.maxlen)
        for i in tf.range(self.maxlen):
            values = values.write(i, tf.range(-i - 1, self.maxlen - 1 - i))
        power = tf.cast(values.stack(), tf.float64)
        values.close()
        power = tf.linalg.band_part(power, 0, -1) - tf.linalg.band_part(
            power, 0, 0) + tril
        tril = tf.transpose(
            tf.linalg.band_part(
                tf.ones((self.maxlen, self.maxlen), dtype=tf.float64),
                self.max_occurence_length, 0)) - tf.eye(self.maxlen,
                                                        dtype=tf.float64)
        gaps = tf.fill([self.maxlen, self.maxlen], self.gap_decay)

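        # D[i, j] = gap_decay ** (j - i - 1) for i < j <= i + max_occurence_length and 0 elsewhere;
        # dD_dgap holds its elementwise derivative w.r.t. gap_decay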
        self.D = tf.pow(gaps * tril, power)
        self.dD_dgap = tf.pow((tril * gaps), (power - 1.0)) * tril * power

    @tf.custom_gradient
    def K_calc(self, X, X2):
        r"""
        Calc the elements of the kernel matrix (and gradients if symmetric)
        """

        # check that no input string is longer than maxlen
        if (tf.reduce_max(tf.strings.length(X)) + 1 > 2 * self.maxlen) or (
                tf.reduce_max(tf.strings.length(X2)) + 1 > 2 * self.maxlen):
            raise ValueError(
                "An input string is longer than maxlen, so refit the kernel with a larger maxlen param"
            )

        # encode strings as integer sequences via the lookup table:
        # split each string into characters and pad until all have length self.maxlen
        X = tf.strings.split(tf.squeeze(X, 1)).to_tensor(
            "PAD", shape=[None, self.maxlen])
        X = self.table.lookup(X)
        if self.symmetric:
            X2 = X
        else:
            # pad until all have length of self.maxlen
            X2 = tf.strings.split(tf.squeeze(X2, 1)).to_tensor(
                "PAD", shape=[None, self.maxlen])
            X2 = self.table.lookup(X2)

        # get the decay tensors D and dD_dgap
        self._precalc()

        # get indices of all possible pairings from X and X2
        # this allows the maximum number of kernel calcs to be batched onto the GPU (rather than just doing individual rows of the Gram matrix)
        indicies_2, indicies_1 = tf.meshgrid(tf.range(0,
                                                      tf.shape(X2)[0]),
                                             tf.range(0,
                                                      tf.shape(X)[0]))
        indicies = tf.concat(
            [tf.reshape(indicies_1, (-1, 1)),
             tf.reshape(indicies_2, (-1, 1))],
            axis=1)
        # if symmetric then only calc upper matrix (fill in rest later)
        if self.symmetric:
            indicies = tf.boolean_mask(
                indicies, tf.greater_equal(indicies[:, 1], indicies[:, 0]))
        # make kernel calcs in batches
        num_batches = tf.cast(
            tf.math.ceil(tf.shape(indicies)[0] / self.batch_size),
            dtype=tf.int32)
        # iterate through batches

        if self.symmetric:
            k_results = tf.TensorArray(tf.float64,
                                       size=0,
                                       dynamic_size=True,
                                       infer_shape=False)
            gap_grads = tf.TensorArray(tf.float64,
                                       size=0,
                                       dynamic_size=True,
                                       infer_shape=False)
            match_grads = tf.TensorArray(tf.float64,
                                         size=0,
                                         dynamic_size=True,
                                         infer_shape=False)
            for i in tf.range(num_batches):
                indicies_batch = indicies[self.batch_size * i:self.batch_size *
                                          (i + 1)]
                X_batch = tf.gather(X, indicies_batch[:, 0], axis=0)
                X2_batch = tf.gather(X2, indicies_batch[:, 1], axis=0)
                results = self._k_grads(X_batch, X2_batch)
                k_results = k_results.write(k_results.size(), results[0])
                gap_grads = gap_grads.write(gap_grads.size(), results[1])
                match_grads = match_grads.write(match_grads.size(), results[2])
            # combine individual kernel results
            k_results = tf.reshape(k_results.concat(), [1, -1])
            gap_grads = tf.reshape(gap_grads.concat(), [1, -1])
            match_grads = tf.reshape(match_grads.concat(), [1, -1])
        else:
            k_results = tf.TensorArray(tf.float64,
                                       size=0,
                                       dynamic_size=True,
                                       infer_shape=False)
            for i in tf.range(num_batches):
                indicies_batch = indicies[self.batch_size * i:self.batch_size *
                                          (i + 1)]
                X_batch = tf.gather(X, indicies_batch[:, 0], axis=0)
                X2_batch = tf.gather(X2, indicies_batch[:, 1], axis=0)
                k_results = k_results.write(k_results.size(),
                                            self._k(X_batch, X2_batch))
            # combine individual kernel results
            k_results = tf.reshape(k_results.concat(), [1, -1])

        # put results into the right places in the gram matrix
        # if symmetric then only put in top triangle (inc diag)
        if self.symmetric:
            mask = tf.linalg.band_part(
                tf.ones((tf.shape(X)[0], tf.shape(X)[0]), dtype=tf.int64), 0,
                -1)
            non_zero = tf.not_equal(mask, tf.constant(0, dtype=tf.int64))
            indices = tf.where(
                non_zero)  # Extracting the indices of upper triangle elements
            out = tf.SparseTensor(indices,
                                  tf.squeeze(k_results),
                                  dense_shape=tf.cast(
                                      (tf.shape(X)[0], tf.shape(X)[0]),
                                      dtype=tf.int64))
            k_results = tf.sparse.to_dense(out)
            out = tf.SparseTensor(indices,
                                  tf.squeeze(gap_grads),
                                  dense_shape=tf.cast(
                                      (tf.shape(X)[0], tf.shape(X)[0]),
                                      dtype=tf.int64))
            gap_grads = tf.sparse.to_dense(out)
            out = tf.SparseTensor(indices,
                                  tf.squeeze(match_grads),
                                  dense_shape=tf.cast(
                                      (tf.shape(X)[0], tf.shape(X)[0]),
                                      dtype=tf.int64))
            match_grads = tf.sparse.to_dense(out)

            # add in missing elements (lower diagonal)
            k_results = k_results + tf.linalg.set_diag(
                tf.transpose(k_results),
                tf.zeros(tf.shape(X)[0], dtype=tf.float64))
            gap_grads = gap_grads + tf.linalg.set_diag(
                tf.transpose(gap_grads),
                tf.zeros(tf.shape(X)[0], dtype=tf.float64))
            match_grads = match_grads + tf.linalg.set_diag(
                tf.transpose(match_grads),
                tf.zeros(tf.shape(X)[0], dtype=tf.float64))
        else:
            k_results = tf.reshape(
                k_results, [tf.shape(X)[0], tf.shape(X2)[0]])

        # normalize if required
        if self.normalize:
            if self.symmetric:
                # if symmetric then can extract normalization terms from gram
                X_diag_Ks = tf.linalg.diag_part(k_results)
                X_diag_gap_grads = tf.linalg.diag_part(gap_grads)
                X_diag_match_grads = tf.linalg.diag_part(match_grads)

                # norm for kernel entries
                norm = tf.tensordot(X_diag_Ks, X_diag_Ks, axes=0)
                k_results = tf.divide(k_results, tf.sqrt(norm))
                # norm for gap_decay and match_decay grads
                diff_gap = tf.divide(
                    tf.tensordot(X_diag_gap_grads, X_diag_Ks, axes=0) +
                    tf.tensordot(X_diag_Ks, X_diag_gap_grads, axes=0),
                    2 * norm)
                diff_match = tf.divide(
                    tf.tensordot(X_diag_match_grads, X_diag_Ks, axes=0) +
                    tf.tensordot(X_diag_Ks, X_diag_match_grads, axes=0),
                    2 * norm)
                gap_grads = tf.divide(gap_grads, tf.sqrt(norm)) - tf.multiply(
                    k_results, diff_gap)
                match_grads = tf.divide(match_grads,
                                        tf.sqrt(norm)) - tf.multiply(
                                            k_results, diff_match)

            else:
                # if not symmetric then need to calculate some extra kernel calcs
                # get diagonal kernel calcs for X1
                X_diag_Ks = tf.TensorArray(tf.float64,
                                           size=0,
                                           dynamic_size=True,
                                           infer_shape=False)
                num_batches = tf.cast(
                    tf.math.ceil(tf.shape(X)[0] / self.batch_size),
                    dtype=tf.int32)
                # iterate through batches
                for i in tf.range(num_batches):
                    X_batch = X[self.batch_size * i:self.batch_size * (i + 1)]
                    X_diag_Ks = X_diag_Ks.write(X_diag_Ks.size(),
                                                self._k(X_batch, X_batch))
                # collect up all batches
                X_diag_Ks = tf.reshape(X_diag_Ks.concat(), (-1, ))

                # get diagonal kernel calcs for X2
                X2_diag_Ks = tf.TensorArray(tf.float64,
                                            size=0,
                                            dynamic_size=True,
                                            infer_shape=False)
                num_batches = tf.cast(
                    tf.math.ceil(tf.shape(X2)[0] / self.batch_size),
                    dtype=tf.int32)
                # iterate through batches
                for i in tf.range(num_batches):
                    X2_batch = X2[self.batch_size * i:self.batch_size *
                                  (i + 1)]
                    X2_diag_Ks = X2_diag_Ks.write(X2_diag_Ks.size(),
                                                  self._k(X2_batch, X2_batch))
                # collect up all batches
                X2_diag_Ks = tf.reshape(X2_diag_Ks.concat(), (-1, ))

                # norm for kernel entries
                norm = tf.tensordot(X_diag_Ks, X2_diag_Ks, axes=0)
                k_results = tf.divide(k_results, tf.sqrt(norm))

        def grad(dy, variables=None):
            if self.symmetric:
                # get gradients of unconstrained params
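                # chain rule: scale by d(constrained)/d(unconstrained) = exp(forward_log_det_jacobian)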
                grads = {}
                grads['gap_decay:0'] = tf.reduce_sum(
                    tf.multiply(
                        dy,
                        gap_grads * tf.math.exp(
                            self.logistic_gap.forward_log_det_jacobian(
                                self.gap_decay_unconstrained, 0))))
                grads['match_decay:0'] = tf.reduce_sum(
                    tf.multiply(
                        dy,
                        match_grads * tf.math.exp(
                            self.logistic_match.forward_log_det_jacobian(
                                self.match_decay_unconstrained, 0))))
                gradient = [grads[v.name] for v in variables]
                return ((None, None), gradient)
            else:
                return ((None, None), [None, None])

        return k_results, grad

    def _k_grads(self, X1, X2):
        r"""
        Vectorized kernel calc and kernel grad calc.
        Following notation from Beck (2017), i.e have tensors S,D,Kpp,Kp
        Input is two tensors of shape (# strings , # characters)
        and we calc the pair-wise kernel calcs between the elements (i.e n kern calcs for two lists of length n)
        D is the tensor than unrolls the recursion and allows vecotrizaiton
        """

        # turn into one-hot, i.e. shape (# strings, # characters, alphabet size + 1)
        X1 = tf.one_hot(X1, len(self.alphabet) + 1, dtype=tf.float64)
        X2 = tf.one_hot(X2, len(self.alphabet) + 1, dtype=tf.float64)
        # remove the ones in the first column that encode the padding (i.e. we don't want padding to count as a match)
        paddings = tf.constant([[0, 0], [0, 0], [0, len(self.alphabet)]])
        X1 = X1 - tf.pad(tf.expand_dims(X1[:, :, 0], 2), paddings, "CONSTANT")
        X2 = X2 - tf.pad(tf.expand_dims(X2[:, :, 0], 2), paddings, "CONSTANT")
        # store squared match coef
        match_sq = tf.square(self.match_decay)
        # Make S: the similarity tensor of shape (# strings, #characters, # characters)
        S = tf.matmul(X1, tf.transpose(X2, perm=(0, 2, 1)))
        # Main loop, where Kp, Kpp values and gradients are calculated.
        Kp = tf.TensorArray(tf.float64,
                            size=0,
                            dynamic_size=True,
                            clear_after_read=False)
        dKp_dgap = tf.TensorArray(tf.float64,
                                  size=0,
                                  dynamic_size=True,
                                  clear_after_read=False)
        dKp_dmatch = tf.TensorArray(tf.float64,
                                    size=0,
                                    dynamic_size=True,
                                    clear_after_read=False)
        Kp = Kp.write(
            Kp.size(),
            tf.ones(shape=tf.stack([tf.shape(X1)[0], self.maxlen,
                                    self.maxlen]),
                    dtype=tf.float64))
        dKp_dgap = dKp_dgap.write(
            dKp_dgap.size(),
            tf.zeros(shape=tf.stack(
                [tf.shape(X1)[0], self.maxlen, self.maxlen]),
                     dtype=tf.float64))
        dKp_dmatch = dKp_dmatch.write(
            dKp_dmatch.size(),
            tf.zeros(shape=tf.stack(
                [tf.shape(X1)[0], self.maxlen, self.maxlen]),
                     dtype=tf.float64))

        # calc subkernels for each subsequence length
        for i in tf.range(0, self.max_subsequence_length - 1):

            Kp_temp = tf.multiply(S, Kp.read(i))
            Kp_temp0 = match_sq * Kp_temp
            Kp_temp1 = tf.matmul(Kp_temp0, self.D)
            Kp_temp2 = tf.matmul(self.D, Kp_temp1, transpose_a=True)
            Kp = Kp.write(Kp.size(), Kp_temp2)

            dKp_dgap_temp_1 = tf.matmul(self.dD_dgap,
                                        Kp_temp1,
                                        transpose_a=True)
            dKp_dgap_temp_2 = tf.multiply(S, dKp_dgap.read(i))
            dKp_dgap_temp_2 = dKp_dgap_temp_2 * match_sq
            dKp_dgap_temp_2 = tf.matmul(dKp_dgap_temp_2, self.D)
            dKp_dgap_temp_2 = dKp_dgap_temp_2 + tf.matmul(
                Kp_temp0, self.dD_dgap)
            dKp_dgap_temp_2 = tf.matmul(self.D,
                                        dKp_dgap_temp_2,
                                        transpose_a=True)
            dKp_dgap = dKp_dgap.write(dKp_dgap.size(),
                                      dKp_dgap_temp_1 + dKp_dgap_temp_2)

            dKp_dmatch_temp_1 = 2 * tf.divide(Kp_temp2, self.match_decay)
            dKp_dmatch_temp_2 = tf.multiply(S, dKp_dmatch.read(i))
            dKp_dmatch_temp_2 = dKp_dmatch_temp_2 * match_sq
            dKp_dmatch_temp_2 = tf.matmul(dKp_dmatch_temp_2, self.D)
            dKp_dmatch_temp_2 = tf.matmul(self.D,
                                          dKp_dmatch_temp_2,
                                          transpose_a=True)
            dKp_dmatch = dKp_dmatch.write(
                dKp_dmatch.size(), dKp_dmatch_temp_1 + dKp_dmatch_temp_2)

        # Final calculation. We gather all Kps
        Kp_stacked = Kp.stack()
        Kp.close()
        dKp_dgap_stacked = dKp_dgap.stack()
        dKp_dgap.close()
        dKp_dmatch_stacked = dKp_dmatch.stack()
        dKp_dmatch.close()

        # get k
        temp = tf.multiply(S, Kp_stacked)
        temp = tf.reduce_sum(temp, -1)
        sum2 = tf.reduce_sum(temp, -1)
        Ki = sum2 * match_sq
        k = tf.reduce_sum(Ki, 0)
        k = tf.expand_dims(k, 1)

        # get gap decay grads
        temp = tf.multiply(S, dKp_dgap_stacked)
        temp = tf.reduce_sum(temp, -1)
        temp = tf.reduce_sum(temp, -1)
        temp = temp * match_sq
        dk_dgap = tf.reduce_sum(temp, 0)
        dk_dgap = tf.expand_dims(dk_dgap, 1)

        # get match decay grads
        temp = tf.multiply(S, dKp_dmatch_stacked)
        temp = tf.reduce_sum(temp, -1)
        temp = tf.reduce_sum(temp, -1)
        temp = temp * match_sq
        temp = temp + 2 * self.match_decay * sum2
        dk_dmatch = tf.reduce_sum(temp, 0)
        dk_dmatch = tf.expand_dims(dk_dmatch, 1)

        return k, dk_dgap, dk_dmatch

    def _k(self, X1, X2):
        r"""
        Vectorized kernel calc.
        Following notation from Beck (2017), i.e have tensors S,D,Kpp,Kp
        Input is two tensors of shape (# strings , # characters)
        and we calc the pair-wise kernel calcs between the elements (i.e n kern calcs for two lists of length n)
        D is the tensor than unrolls the recursion and allows vecotrizaiton
        """

        # turn into one-hot, i.e. shape (# strings, # characters, alphabet size + 1)
        X1 = tf.one_hot(X1, len(self.alphabet) + 1, dtype=tf.float64)
        X2 = tf.one_hot(X2, len(self.alphabet) + 1, dtype=tf.float64)
        # remove the ones in the first column that encode the padding (i.e. we don't want padding to count as a match)
        paddings = tf.constant([[0, 0], [0, 0], [0, len(self.alphabet)]])
        X1 = X1 - tf.pad(tf.expand_dims(X1[:, :, 0], 2), paddings, "CONSTANT")
        X2 = X2 - tf.pad(tf.expand_dims(X2[:, :, 0], 2), paddings, "CONSTANT")
        # store squared match coef
        match_sq = tf.square(self.match_decay)
        # Make S: the similarity tensor of shape (# strings, #characters, # characters)
        S = tf.matmul(X1, tf.transpose(X2, perm=(0, 2, 1)))
        # Main loop, where Kp, Kpp values and gradients are calculated.
        Kp = tf.TensorArray(tf.float64,
                            size=0,
                            dynamic_size=True,
                            clear_after_read=False)
        Kp = Kp.write(
            Kp.size(),
            tf.ones(shape=tf.stack([tf.shape(X1)[0], self.maxlen,
                                    self.maxlen]),
                    dtype=tf.float64))

        # calc subkernels for each subsequence length
        for i in tf.range(0, self.max_subsequence_length - 1):
            temp = tf.multiply(S, Kp.read(i))
            temp = tf.matmul(temp, self.D)
            temp = tf.matmul(self.D, temp, transpose_a=True)
            temp = match_sq * temp
            Kp = Kp.write(Kp.size(), temp)

        # Final calculation. We gather all Kps
        Kp_stacked = Kp.stack()
        Kp.close()

        # Get k
        aux = tf.multiply(S, Kp_stacked)
        aux = tf.reduce_sum(aux, -1)
        sum2 = tf.reduce_sum(aux, -1)
        Ki = tf.multiply(sum2, match_sq)
        k = tf.reduce_sum(Ki, 0)
        k = tf.expand_dims(k, 1)

        return k
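A minimal usage sketch for this kernel (the alphabet, strings, and hyperparameter values below are illustrative assumptions, not taken from the original source). Inputs are (N, 1) arrays of strings whose characters are separated by single spaces and drawn from the supplied alphabet:

import numpy as np

kernel = StringKernel(alphabet=["a", "c", "g", "t"],
                      maxlen=5,
                      max_subsequence_length=3,
                      max_occurence_length=5,
                      normalize=True)
X = np.array([["a c g t a"], ["a a c c g"], ["t t g a c"]])

K = kernel.K(X)            # (3, 3) Gram matrix (symmetric path, custom gradients attached)
K_diag = kernel.K_diag(X)  # all ones, since normalize=True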
Example 2
class Batch_simple_SSK(Kernel):
    """
   with hyperparameters:
    1) match_decay float
        decrease the contribution of long subsequences
    3) max_subsequence_length int 
        largest subsequence considered
    """

    def __init__(self,active_dims=[0],decay=0.1,max_subsequence_length=3,
                 alphabet = [], maxlen=0, batch_size=100):
        super().__init__(active_dims=active_dims)
        # constrain decay kernel params to between 0 and 1
        self.logistic = tfb.Chain([tfb.Shift(tf.cast(0,tf.float64))(tfb.Scale(tf.cast(1,tf.float64))),tfb.Sigmoid()])
        self.decay_param= Parameter(decay, transform=self.logistic ,name="decay")

        # we use copies of the kernel params to avoid building an expensive computation graph;
        # we instead efficiently calculate gradients using dynamic programming
        # these copies are refreshed at every call to K and K_diag (in case the parameters have been updated)
        self.decay = self.decay_param.numpy()

        self.decay_unconstrained = self.decay_param.unconstrained_variable.numpy()

        self.order_coefs=tf.ones(max_subsequence_length,dtype=tf.float64)
        
        # store additional kernel parameters
        self.max_subsequence_length = tf.constant(max_subsequence_length)
        self.alphabet =  tf.constant(alphabet)
        self.alphabet_size=tf.shape(self.alphabet)[0]
        self.maxlen =  tf.constant(maxlen)
        self.batch_size = tf.constant(batch_size)

        # build a lookup table of the alphabet to encode input strings
        self.table = tf.lookup.StaticHashTable(
            initializer=tf.lookup.KeyValueTensorInitializer(
                keys=tf.constant(["PAD"]+alphabet),
                values=tf.constant(range(0,len(alphabet)+1)),),default_value=0)

        # initialize helpful construction matrices, to be lazily computed once needed
        self.D = None
        self.dD_dgap = None


    def K_diag(self, X):
        r"""
        The diagonal elements of the string kernel are always unity (due to normalisation)
        """
        return tf.ones(tf.shape(X)[:-1],dtype=tf.float64)



    def K(self, X1, X2=None):
        r"""
        Vectorized kernel calc.
        Following notation from Beck (2017), i.e have tensors S,D,Kpp,Kp
        Input is two tensors of shape (# strings , # characters)
        and we calc the pair-wise kernel calcs between the elements (i.e n kern calcs for two lists of length n)
        D is the tensor than unrolls the recursion and allows vecotrizaiton
        """

        # encode strings as integer sequences via the lookup table:
        # split each string into characters, pad until all have length self.maxlen,
        # then turn into one-hot, i.e. shape (# strings, # characters, alphabet size + 1)
        X1 = tf.strings.split(tf.squeeze(X1,1)).to_tensor("PAD",shape=[None,self.maxlen])
        X1 = self.table.lookup(X1)
        # keep track of original input sizes
        X1_shape = tf.shape(X1)[0]
        X1 = tf.one_hot(X1,self.alphabet_size+1,dtype=tf.float64)
        if X2 is None:
            X2 = X1
            X2_shape = X1_shape
            self.symmetric = True
        else:
            self.symmetric = False
            X2 = tf.strings.split(tf.squeeze(X2,1)).to_tensor("PAD",shape=[None,self.maxlen])
            X2 = self.table.lookup(X2)
            X2_shape = tf.shape(X2)[0]
            X2 = tf.one_hot(X2,self.alphabet_size+1,dtype=tf.float64)
  
        # prep the decay tensors 
        self._precalc()
      


        # combine all target strings and remove the ones in the first column that encode the padding (i.e. we don't want padding to count as a match)
        X_full = tf.concat([X1,X2],0)[:,:,1:]

        # get indices of all possible pairings from X and X2
        # this allows the maximum number of kernel calcs to be batched onto the GPU (rather than just doing individual rows of the Gram matrix)
        indicies_2, indicies_1 = tf.meshgrid(tf.range(0, X1_shape ),tf.range(X1_shape , tf.shape(X_full)[0]))
        indicies = tf.concat([tf.reshape(indicies_1,(-1,1)),tf.reshape(indicies_2,(-1,1))],axis=1)
        if self.symmetric:
            # if symmetric then only calc upper matrix (fill in rest later)
            indicies = tf.boolean_mask(indicies,tf.greater_equal(indicies[:,1]+ X1_shape ,indicies[:,0]))
        else:
            # if not symmetric need to calculate some extra kernel evals for the normalization later on
            indicies = tf.concat([indicies,tf.tile(tf.expand_dims(tf.range(tf.shape(X_full)[0]),1),(1,2))],0)

        # make kernel calcs in batches
        num_batches = tf.cast(tf.math.ceil(tf.shape(indicies)[0]/self.batch_size),dtype=tf.int32)
        k_split =  tf.TensorArray(tf.float64, size=num_batches,clear_after_read=False,infer_shape=False)
        

        # iterate through batches
        for j in tf.range(num_batches):
            # collect strings for this batch
            indicies_batch = indicies[self.batch_size*j:self.batch_size*(j+1)]
            X_batch = tf.gather(X_full,indicies_batch[:,0],axis=0)
            X2_batch = tf.gather(X_full,indicies_batch[:,1],axis=0)

            # Make S: the similarity tensor of shape (# strings, #characters, # characters)
            #S = tf.matmul( tf.matmul(X_batch,self.sim),tf.transpose(X2_batch,perm=(0,2,1)))
            S = tf.matmul(X_batch,tf.transpose(X2_batch,perm=(0,2,1)))
            # collect results for the batch
            result = self.kernel_calc(S)
            k_split = k_split.write(j,result)

        # combine batch results
        k = tf.expand_dims(k_split.concat(),1)
        k_split.close()

        # put results into the right places in the gram matrix and normalize
        if self.symmetric:
            # if symmetric then only put in top triangle (inc diag)
            mask = tf.linalg.band_part(tf.ones((X1_shape,X2_shape),dtype=tf.int64), 0, -1)
            non_zero = tf.not_equal(mask, tf.constant(0, dtype=tf.int64))
            
            # Extracting the indices of upper triangle elements
            indices = tf.where(non_zero)
            out = tf.SparseTensor(indices,tf.squeeze(k),dense_shape=tf.cast((X1_shape,X2_shape),dtype=tf.int64))
            k_results = tf.sparse.to_dense(out)
            
            # add in missing elements (lower diagonal)
            k_results = k_results + tf.linalg.set_diag(tf.transpose(k_results),tf.zeros(X1_shape,dtype=tf.float64))
            
            # normalise
            X_diag_Ks = tf.linalg.diag_part(k_results)
            norm = tf.tensordot(X_diag_Ks, X_diag_Ks,axes=0)
            k_results = tf.divide(k_results, tf.sqrt(norm))
        else:

            # otherwise can just reshape into gram matrix
            # but first take extra kernel calcs off end of k and use them to normalise
            X_diag_Ks = tf.reshape(k[X1_shape*X2_shape:X1_shape*X2_shape+X1_shape],(-1,))
            X2_diag_Ks = tf.reshape(k[-X2_shape:],(-1,))
            k = k[0:X1_shape*X2_shape]
            k_results = tf.transpose(tf.reshape(k,[X2_shape,X1_shape]))
            # normalise
            norm = tf.tensordot(X_diag_Ks, X2_diag_Ks,axes=0)
            k_results = tf.divide(k_results, tf.sqrt(norm))


        return k_results


    def _precalc(self):
        r"""
        Update stored kernel params (in case they have changed)
        and precalc D and dD_dgap as required for kernel calcs
        following notation from Beck (2017)
        """
        self.decay = self.decay_param.numpy()
        self.decay_unconstrained = self.decay_param.unconstrained_variable.numpy()

        tril =  tf.linalg.band_part(tf.ones((self.maxlen,self.maxlen),dtype=tf.float64), -1, 0)
        # get upper triangle matrix of increasing integers
        values = tf.TensorArray(tf.int32, size= self.maxlen)
        for i in tf.range(self.maxlen):
            values = values.write(i,tf.range(-i-1,self.maxlen-1-i)) 
        power = tf.cast(values.stack(),tf.float64)
        values.close()
        power = tf.linalg.band_part(power, 0, -1) - tf.linalg.band_part(power, 0, 0) + tril
        tril = tf.transpose(tf.linalg.band_part(tf.ones((self.maxlen,self.maxlen),dtype=tf.float64), -1, 0))-tf.eye(self.maxlen,dtype=tf.float64)
        gaps = tf.fill([self.maxlen, self.maxlen],self.decay)
        
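        # D[i, j] = decay ** (j - i - 1) for j > i and 0 elsewhere;
        # dD_dgap holds its elementwise derivative w.r.t. decay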
        self.D = tf.pow(gaps*tril, power)
        self.dD_dgap = tf.pow((tril * gaps), (power - 1.0)) * tril * power



    @tf.custom_gradient
    def kernel_calc(self,S):

        # a fake computation to ensure the decay parameter enters the graph, so its variable is passed to the custom gradient below
        a = tf.square(self.decay_param)

        if self.symmetric:
            k, dk_dgap = tf.stop_gradient(self.kernel_calc_with_grads(S))
        else:
            k = tf.stop_gradient(self.kernel_calc_without_grads(S))


        def grad(dy, variables=None):
            # get gradients of unconstrained params
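            # chain rule: scale by d(constrained)/d(unconstrained) = exp(forward_log_det_jacobian)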
            grads= {}
            if self.symmetric:
                grads['decay:0'] = tf.reduce_sum(tf.multiply(dy,dk_dgap*tf.math.exp(self.logistic.forward_log_det_jacobian(self.decay_unconstrained,0))))
                gradient = [grads[v.name] for v in variables]
            else:
                gradient = [None for v in variables]
            return ((None),gradient)


        return k, grad

    def kernel_calc_without_grads(self,S):

        # store squared match coef for easier calc later
        match_sq = tf.square(self.decay)


        # calc subkernels for each subsequence length (See Moss et al. 2020 for notation)
        Kp = tf.TensorArray(tf.float64,size=self.max_subsequence_length,clear_after_read=False)

        # fill in first entries
        Kp = Kp.write(0, tf.ones(shape=tf.stack([tf.shape(S)[0], self.maxlen,self.maxlen]), dtype=tf.float64))

        # calculate dynamic programs
        for i in tf.range(self.max_subsequence_length-1):
            Kp_temp = tf.multiply(S, Kp.read(i))
            Kp_temp0 =  match_sq * Kp_temp
            Kp_temp1 = tf.matmul(Kp_temp0,self.D)
            Kp_temp2 = tf.matmul(self.D,Kp_temp1,transpose_a=True)
            Kp = Kp.write(i+1,Kp_temp2)

        # Final calculation. We gather all Kps 
        Kp_stacked = Kp.stack()
        Kp.close()

        # combine and get overall kernel
        aux = tf.multiply(S, Kp_stacked)
        aux = tf.reduce_sum(aux, -1)
        sum2 = tf.reduce_sum(aux, -1)
        Ki = sum2 * match_sq
        k = tf.linalg.matvec(tf.transpose(Ki),self.order_coefs)

        return k

    
    def kernel_calc_with_grads(self,S):
        # store squared match coef for easier calc later
        match_sq = tf.square(self.decay)
        # calc subkernels for each subsequence length (See Moss et al. 2020 for notation)
        Kp = tf.TensorArray(tf.float64,size=self.max_subsequence_length,clear_after_read=False)
        dKp_dgap = tf.TensorArray(tf.float64, size=self.max_subsequence_length, clear_after_read=False)

        # fill in first entries
        Kp = Kp.write(0, tf.ones(shape=tf.stack([tf.shape(S)[0], self.maxlen,self.maxlen]), dtype=tf.float64))
        dKp_dgap = dKp_dgap.write(0, tf.zeros(shape=tf.stack([tf.shape(S)[0], self.maxlen,self.maxlen]), dtype=tf.float64))

        # calculate dynamic programs
        for i in tf.range(self.max_subsequence_length-1):
            Kp_temp = tf.multiply(S, Kp.read(i))
            Kp_temp0 =  match_sq * Kp_temp
            Kp_temp1 = tf.matmul(Kp_temp0,self.D)
            Kp_temp2 = tf.matmul(self.D,Kp_temp1,transpose_a=True)
            Kp = Kp.write(i+1,Kp_temp2)

            dKp_dgap_temp_1 =  tf.matmul(self.dD_dgap,Kp_temp1,transpose_a=True)
            dKp_dgap_temp_2 =  tf.multiply(S, dKp_dgap.read(i))
            dKp_dgap_temp_2 = dKp_dgap_temp_2 * match_sq
            dKp_dgap_temp_2 = tf.matmul(dKp_dgap_temp_2,self.D)
            dKp_dgap_temp_2 = dKp_dgap_temp_2 + tf.matmul(Kp_temp0,self.dD_dgap)
            dKp_dgap_temp_2 = tf.matmul(self.D,dKp_dgap_temp_2,transpose_a=True)
            dKp_dgap = dKp_dgap.write(i+1,dKp_dgap_temp_1 + dKp_dgap_temp_2)



        # Final calculation. We gather all Kps 
        Kp_stacked = Kp.stack()
        Kp.close()
        dKp_dgap_stacked = dKp_dgap.stack()
        dKp_dgap.close()


        # combine and get overall kernel

        # get k
        aux = tf.multiply(S, Kp_stacked)
        aux = tf.reduce_sum(aux, -1)
        sum2 = tf.reduce_sum(aux, -1)
        Ki = sum2 * match_sq
        k = tf.linalg.matvec(tf.transpose(Ki),self.order_coefs)

        # get gap decay grads
        temp = tf.multiply(S, dKp_dgap_stacked)
        temp = tf.reduce_sum(temp, -1)
        temp = tf.reduce_sum(temp, -1)
        temp = temp * match_sq
        dk_dgap = tf.linalg.matvec(tf.transpose(temp),self.order_coefs)


        return k, dk_dgap
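A similar usage sketch for this batched kernel (again with illustrative, assumed data):

import numpy as np

kernel = Batch_simple_SSK(alphabet=["a", "c", "g", "t"],
                          maxlen=5,
                          max_subsequence_length=3,
                          batch_size=100)
X = np.array([["a c g t a"], ["a a c c g"]])
X2 = np.array([["t t g a c"]])

K = kernel.K(X)            # (2, 2) normalised Gram matrix
K_cross = kernel.K(X, X2)  # (2, 1) normalised cross-covariance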