Example #1
 def apply_softmax(self, wrapper):
     ub = self.upper
     lb = self.lower
     # Keep diagonal and take opposite bound for non-diagonals.
     lbs = tf.matrix_diag(lb) + tf.expand_dims(ub,
                                               axis=-2) - tf.matrix_diag(ub)
     ubs = tf.matrix_diag(ub) + tf.expand_dims(lb,
                                               axis=-2) - tf.matrix_diag(lb)
     # Get diagonal entries after softmax operation.
     ubs = tf.matrix_diag_part(tf.nn.softmax(ubs))
     lbs = tf.matrix_diag_part(tf.nn.softmax(lbs))
     return IntervalBounds(lbs, ubs)
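A minimal NumPy sketch (my own illustration, not part of the original code) of the opposite-bound trick used above: softmax output i is largest when logit i sits at its upper bound and every other logit sits at its lower bound, and smallest in the reverse configuration.

import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

lb = np.array([0.0, -1.0, 0.5])   # assumed lower bounds on the logits
ub = np.array([1.0,  0.0, 1.5])   # assumed upper bounds on the logits

# Row i keeps its own bound at position i and the opposite bound elsewhere,
# mirroring the tf.matrix_diag / tf.expand_dims construction above.
upper_logits = np.diag(ub) + lb[None, :] - np.diag(lb)
lower_logits = np.diag(lb) + ub[None, :] - np.diag(ub)

softmax_ub = np.diagonal(softmax(upper_logits))
softmax_lb = np.diagonal(softmax(lower_logits))

# Any logit vector inside the box stays inside the resulting softmax bounds.
z = np.random.uniform(lb, ub)
assert np.all(softmax_lb <= softmax(z)) and np.all(softmax(z) <= softmax_ub)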
Example #2
    def _slice_cov(self, cov):
        """
        Slice the correct dimensions for use in the kernel, as indicated by
        `self.active_dims` for covariance matrices. This requires slicing the
        rows *and* columns. This will also turn flattened diagonal
        matrices into a tensor of full diagonal matrices.
        :param cov: Tensor of covariance matrices (NxDxD or NxD).
        :return: N x self.input_dim x self.input_dim.
        """
        cov = tf.cond(tf.equal(tf.rank(cov), 2), lambda: tf.matrix_diag(cov),
                      lambda: cov)

        if isinstance(self.active_dims, slice):
            cov = cov[..., self.active_dims, self.active_dims]
        else:
            cov_shape = tf.shape(cov)
            covr = tf.reshape(cov, [-1, cov_shape[-1], cov_shape[-1]])
            gather1 = tf.gather(tf.transpose(covr, [2, 1, 0]),
                                self.active_dims)
            gather2 = tf.gather(tf.transpose(gather1, [1, 0, 2]),
                                self.active_dims)
            cov = tf.reshape(
                tf.transpose(gather2, [2, 0, 1]),
                tf.concat([
                    cov_shape[:-2],
                    [len(self.active_dims),
                     len(self.active_dims)]
                ], 0))
        return cov
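For reference, a hedged NumPy sketch (not part of the original class) of what the gather/transpose sequence computes when `active_dims` is a list: the same dimensions are selected along both of the last two axes of each covariance matrix.

import numpy as np

active_dims = [0, 2]              # assumed example indices
cov = np.random.rand(5, 4, 4)     # N x D x D batch of covariances

# Slice rows, then columns; result is N x len(active_dims) x len(active_dims).
sliced = cov[:, active_dims][:, :, active_dims]
assert sliced.shape == (5, 2, 2)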
Example #3
    def log_prob_fn(params):
      rho, alpha, sigma = tf.split(params, [num_features, 1, 1], -1)

      one = tf.ones(num_features)
      def indep(d):
        return tfd.Independent(d, 1)
      p_rho = indep(tfd.InverseGamma(5. * one, 5. * one))
      p_alpha = indep(tfd.HalfNormal([1.]))
      p_sigma = indep(tfd.HalfNormal([1.]))

      rho_shape = tf.shape(rho)
      alpha_shape = tf.shape(alpha)

      x1 = tf.expand_dims(x, -2)
      x2 = tf.expand_dims(x, -3)
      exp = -0.5 * tf.squared_difference(x1, x2)
      exp /= tf.reshape(tf.square(rho), tf.concat([rho_shape[:1], [1, 1], rho_shape[1:]], 0))
      exp = tf.reduce_sum(exp, -1, keep_dims=True)
      exp += 2. * tf.reshape(tf.log(alpha), tf.concat([alpha_shape[:1], [1, 1], alpha_shape[1:]], 0))
      exp = tf.exp(exp[Ellipsis, 0])
      exp += tf.matrix_diag(tf.tile(tf.square(sigma), [1, int(x.shape[0])]) + 1e-6)
      exp = tf.check_numerics(exp, "exp 2 has NaNs")
      with tf.control_dependencies([tf.print(exp[0], summarize=99999)]):
        exp = tf.identity(exp)

      p_y = tfd.MultivariateNormalFullCovariance(
          covariance_matrix=exp)

      log_prob = (
          p_rho.log_prob(rho) + p_alpha.log_prob(alpha) +
          p_sigma.log_prob(sigma) + p_y.log_prob(y))

      return log_prob
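My reading of the covariance assembled above, written out: a squared-exponential kernel with per-feature length scales rho, amplitude alpha and observation noise sigma (including the same 1e-6 jitter):

$$k(x_i, x_j) = \alpha^2 \exp\!\Big(-\tfrac{1}{2}\sum_d \frac{(x_{i,d} - x_{j,d})^2}{\rho_d^2}\Big) + (\sigma^2 + 10^{-6})\,\delta_{ij}$$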
Example #4
    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        if dimension_value(input_shape[-1]) is None:
            raise ValueError("The last dimension of the inputs to `Dense` "
                             "should be defined. Found `None`.")
        last_dim = dimension_value(input_shape[-1])
        self.input_spec = tf.keras.layers.InputSpec(min_ndim=2,
                                                    axes={-1: last_dim})

        self._c = tf.get_variable(
            "c", [self._decoder_dim, self._rank],
            initializer=tf.contrib.layers.xavier_initializer(),
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True)

        sigma = tf.matmul(self._mem_input, self._c)
        if self._sigma_norm > 0.:
            sigma = tf.nn.l2_normalize(sigma, axis=1) * self._sigma_norm
        elif self._sigma_norm == -1.:
            sigma = tf.nn.softmax(sigma / self._tau, axis=1)
        sigma_diag = tf.matrix_diag(sigma)

        self._u = tf.get_variable(
            "u", [last_dim, self._rank],
            initializer=tf.contrib.layers.xavier_initializer(),
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True)
        self._v = tf.get_variable(
            "v", [self._rank, self.units],
            initializer=tf.contrib.layers.xavier_initializer(),
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True)
        self.kernel = tf.einsum("ij,ajk,kl->ail", self._u, sigma_diag, self._v)
        if self._use_beam and self._beam_width:
            self.kernel = tf.contrib.seq2seq.tile_batch(
                self.kernel, multiplier=self._beam_width)

        if self.use_bias:
            self._b = self.add_weight("b",
                                      shape=[self.units, self._rank],
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint,
                                      dtype=self.dtype,
                                      trainable=True)
            self.bias = tf.einsum("ij,aj->ai", self._b, sigma)
            if self._use_beam and self._beam_width:
                self.bias = tf.contrib.seq2seq.tile_batch(
                    self.bias, multiplier=self._beam_width)
        else:
            self.bias = None
        self.built = True
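The tf.einsum("ij,ajk,kl->ail", ...) pattern above (also used in Example #6 below) builds one low-rank kernel per memory entry a; in my notation,

$$W_a = U\,\operatorname{diag}(\sigma_a)\,V,$$

with the shared factors U of shape [last_dim, rank] and V of shape [rank, units], so each batch element only re-weights the rank dimensions.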
Example #5
def linear_certain_activations(x_certain, A, b):
    """
    compute y = x^T A + b
    assuming x has zero variance
    """
    x_mean = x_certain
    xx = x_mean * x_mean
    y_mean = tf.matmul(x_mean, A.mean) + b.mean
    y_cov = tf.matrix_diag(tf.matmul(xx, A.var) + b.var)
    return gv.GaussianVar(y_mean, y_cov)
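The moment propagation implemented here (assuming elementwise-independent uncertainty on A and b, which is consistent with the diagonal covariance returned) is

$$\mathbb{E}[y_j] = \sum_i x_i\,\mathbb{E}[A_{ij}] + \mathbb{E}[b_j], \qquad \operatorname{Var}(y_j) = \sum_i x_i^2\,\operatorname{Var}(A_{ij}) + \operatorname{Var}(b_j),$$

which is exactly tf.matmul(xx, A.var) + b.var placed on the diagonal by tf.matrix_diag.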
Example #6
    def initialize(self):
        """Construct RNN params.
    """

        sigma = tf.matmul(self._mem_input, self._c)

        if self._sigma_norm > 0.:
            sigma = tf.nn.l2_normalize(sigma, axis=1) * self._sigma_norm
        elif self._sigma_norm == -1.:
            sigma = tf.nn.softmax(sigma, axis=1)
        sigma_diag = tf.matrix_diag(sigma)

        # The weight matrices.
        # {`x`: input, `h`: hidden_state}.
        # {`i`: input_gate, `j`: cell_state, `f`: forget_gate, `o`: output_gate}.

        # Weight matrix that maps input `x` to input_gate `i`.
        w_xi = tf.einsum("ij,ajk,kl->ail", self._u_xi, sigma_diag, self._v_xi)

        # Weight matrix that maps input `x` to cell_state `j`.
        w_xj = tf.einsum("ij,ajk,kl->ail", self._u_xj, sigma_diag, self._v_xj)

        # Weight matrix that maps input `x` to forget_gate `f`.
        w_xf = tf.einsum("ij,ajk,kl->ail", self._u_xf, sigma_diag, self._v_xf)

        # Weight matrix that maps input `x` to output_gate `o`.
        w_xo = tf.einsum("ij,ajk,kl->ail", self._u_xo, sigma_diag, self._v_xo)

        # Weight matrix that maps hidden_state `h` to input_gate `i`.
        w_hi = tf.einsum("ij,ajk,kl->ail", self._u_hi, sigma_diag, self._v_hi)

        # Weight matrix that maps hidden_state `h` to cell_state `j`.
        w_hj = tf.einsum("ij,ajk,kl->ail", self._u_hj, sigma_diag, self._v_hj)

        # Weight matrix that maps hidden_state `h` to forget_gate `f`.
        w_hf = tf.einsum("ij,ajk,kl->ail", self._u_hf, sigma_diag, self._v_hf)

        # Weight matrix that maps hidden_state `h` to output_gate `o`.
        w_ho = tf.einsum("ij,ajk,kl->ail", self._u_ho, sigma_diag, self._v_ho)

        w_x = tf.concat([w_xi, w_xj, w_xf, w_xo], axis=2)
        w_h = tf.concat([w_hi, w_hj, w_hf, w_ho], axis=2)

        self._weight = tf.concat([w_x, w_h], axis=1)
        self._bias = tf.einsum("ij,aj->ai", self._b, sigma)

        if self._use_beam and self._beam_width > 1:
            self._weight = tf.contrib.seq2seq.tile_batch(
                self._weight, multiplier=self._beam_width)
            self._bias = tf.contrib.seq2seq.tile_batch(
                self._bias, multiplier=self._beam_width)
Example #7
    def __init__(self, priors, **kwargs):
        Prior.__init__(self)
        self.priors = priors
        self.name = kwargs.get("name", "FactPrior")
        self.nparams = len(priors)

        means = [prior.mean for prior in self.priors]
        variances = [prior.var for prior in self.priors]
        self.mean = self.log_tf(tf.stack(means, axis=-1, name="%s_mean" % self.name))
        self.var = self.log_tf(tf.stack(variances, axis=-1, name="%s_var" % self.name))
        self.std = tf.sqrt(self.var, name="%s_std" % self.name)
        self.nvertices = priors[0].nvertices

        # Define a diagonal covariance matrix for convenience
        self.cov = tf.matrix_diag(self.var, name='%s_cov' % self.name)
Example #8
 def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled):
     cov_z_mean = compute_covariance_z_mean(z_mean)
     lambda_d = self.lambda_d_factor * self.lambda_od
     if self.dip_type == "i":  # Eq 6 page 4
         # mu = z_mean is [batch_size, num_latent]
         # Compute cov_p(x) [mu(x)] = E[mu*mu^T] - E[mu]E[mu]^T
         cov_dip_regularizer = regularize_diag_off_diag_dip(
             cov_z_mean, self.lambda_od, lambda_d)
     elif self.dip_type == "ii":
         cov_enc = tf.matrix_diag(tf.exp(z_logvar))
         expectation_cov_enc = tf.reduce_mean(cov_enc, axis=0)
         cov_z = expectation_cov_enc + cov_z_mean
         cov_dip_regularizer = regularize_diag_off_diag_dip(
             cov_z, self.lambda_od, lambda_d)
     else:
         raise NotImplementedError("DIP variant not supported.")
     return kl_loss + cov_dip_regularizer
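compute_covariance_z_mean is not shown in this listing; a plausible NumPy sketch of the quantity described in the comment above (Cov_p(x)[mu(x)] = E[mu mu^T] - E[mu]E[mu]^T) would be:

import numpy as np

def compute_covariance_z_mean_sketch(z_mean):
    # z_mean: [batch_size, num_latent] array; returns [num_latent, num_latent].
    e_mu_mu_t = z_mean.T @ z_mean / z_mean.shape[0]   # E[mu mu^T]
    e_mu = z_mean.mean(axis=0)                        # E[mu]
    return e_mu_mu_t - np.outer(e_mu, e_mu)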
Example #9
    def grad(grad_e, grad_v):
        """Gradient for SelfAdjointEigV2."""
        with tf.control_dependencies([grad_e, grad_v]):
            ediffs = tf.expand_dims(e, -2) - tf.expand_dims(e, -1)

            # Avoid NaNs from reciprocals when eigenvalues are close.
            safe_recip = tf.where(ediffs**2 < 1e-10, tf.zeros_like(ediffs),
                                  tf.reciprocal(ediffs))
            f = tf.matrix_set_diag(safe_recip, tf.zeros_like(e))
            grad_a = tf.matmul(
                v,
                tf.matmul(tf.matrix_diag(grad_e) +
                          f * tf.matmul(v, grad_v, adjoint_a=True),
                          v,
                          adjoint_b=True))
        # The forward op only depends on the lower triangular part of a, so here we
        # symmetrize and take the lower triangle
        grad_a = tf.linalg.band_part(grad_a + tf.linalg.adjoint(grad_a), -1, 0)
        grad_a = tf.linalg.set_diag(grad_a, 0.5 * tf.matrix_diag_part(grad_a))
        return grad_a
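For reference, this is the standard backward rule for the symmetric eigendecomposition A = V diag(e) V^T that the code implements:

$$\bar{A} = V\Big(\operatorname{diag}(\bar{e}) + F \circ (V^{\top}\bar{V})\Big)V^{\top}, \qquad F_{ij} = \frac{1}{e_j - e_i}\ (i \neq j),\quad F_{ii} = 0,$$

with the reciprocals zeroed out when eigenvalues nearly coincide, and the result symmetrised onto the lower triangle because the forward op only reads that part of A.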
Example #10
    def __init__(self, posts, **kwargs):
        FactorisedPosterior.__init__(self, posts, **kwargs)

        # The full covariance matrix is formed from the Cholesky decomposition
        # to ensure that it remains positive definite.
        #
        # To achieve this, we have to create PxP tensor variables for
        # each parameter vertex, but we then extract only the lower triangular
        # elements and train only on these. The diagonal elements
        # are constructed by the FactorisedPosterior
        if kwargs.get("init", None):
            # We are initializing from an existing posterior.
            # The FactorisedPosterior will already have extracted the mean and
            # diagonal of the covariance matrix - we need the Cholesky decomposition
            # of the covariance to initialize the off-diagonal terms
            self.log.info(" - Initializing posterior covariance from input posterior")
            _mean, cov = kwargs["init"]
            covar_init = tf.cholesky(cov)
        else:
            covar_init = tf.zeros([self.nvertices, self.nparams, self.nparams], dtype=tf.float32)

        self.off_diag_vars_base = self.log_tf(tf.Variable(covar_init, validate_shape=False,
                                                     name='%s_off_diag_vars' % self.name))
        if kwargs.get("suppress_nan", True):
            self.off_diag_vars = tf.where(tf.is_nan(self.off_diag_vars_base), tf.zeros_like(self.off_diag_vars_base), self.off_diag_vars_base)
        else:
            self.off_diag_vars = self.off_diag_vars_base
        self.off_diag_cov_chol = tf.matrix_set_diag(tf.matrix_band_part(self.off_diag_vars, -1, 0),
                                                    tf.zeros([self.nvertices, self.nparams]),
                                                    name='%s_off_diag_cov_chol' % self.name)

        # Combine diagonal and off-diagonal elements into full matrix
        self.cov_chol = tf.add(tf.matrix_diag(self.std), self.off_diag_cov_chol,
                               name='%s_cov_chol' % self.name)

        # Form the covariance matrix from the chol decomposition
        self.cov = tf.matmul(tf.transpose(self.cov_chol, perm=(0, 2, 1)), self.cov_chol,
                             name='%s_cov' % self.name)

        self.cov_chol = self.log_tf(self.cov_chol)
        self.cov = self.log_tf(self.cov)
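A minimal NumPy sketch (illustration only) of why this construction keeps the covariance positive semi-definite: for any square factor C, C^T C has non-negative eigenvalues.

import numpy as np

std = np.array([1.0, 0.5, 2.0])                  # assumed diagonal standard deviations
off_diag = np.tril(np.random.randn(3, 3), k=-1)  # strictly lower triangular part
chol = np.diag(std) + off_diag                   # mirrors cov_chol above
cov = chol.T @ chol
assert np.all(np.linalg.eigvalsh(cov) >= -1e-10)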
Example #11
    def __init__(self, posts, **kwargs):
        Posterior.__init__(self, -1, **kwargs)
        self.posts = posts
        self.nparams = len(self.posts)
        self.name = kwargs.get("name", "FactPost")

        means = [post.mean for post in self.posts]
        variances = [post.var for post in self.posts]
        mean = tf.stack(means, axis=-1, name="%s_mean" % self.name)
        var = tf.stack(variances, axis=-1, name="%s_var" % self.name)

        self.mean = self.log_tf(tf.identity(mean, name="%s_mean" % self.name))
        self.var = self.log_tf(tf.identity(var, name="%s_var" % self.name))
        self.std = tf.sqrt(self.var, name="%s_std" % self.name)
        self.nvertices = posts[0].nvertices

        # Covariance matrix is diagonal
        self.cov = tf.matrix_diag(self.var, name='%s_cov' % self.name)

        # Regularisation to make sure cov is invertible. Note that we do not
        # need this for a diagonal covariance matrix but it is useful for
        # the full MVN covariance which shares some of the calculations
        self.cov_reg = 1e-5*tf.eye(self.nparams)
Example #12
def test_fac_init_mvn():
    """ Test factorised posterior initialized from a full MVN"""
    with tf.Session() as session:
        nparams_in = 4
        nvoxels_in = 34
        posts = []
        means = np.random.normal(5.0, 3.0, [nvoxels_in, nparams_in])
        variances = np.square(np.random.normal(2.5, 1.6, [nvoxels_in, nparams_in]))
        name = "TestFactorisedPosterior"
        for param in range(nparams_in):
            posts.append(NormalPosterior(means[:, param], variances[:, param]))

        mvn_post = MVNPosterior(posts, name=name)

        session.run(tf.global_variables_initializer())
        init=(session.run(mvn_post.mean), session.run(mvn_post.cov))
        fac_post = FactorisedPosterior(posts, init=init)

        session.run(tf.global_variables_initializer())
        assert np.allclose(session.run(fac_post.mean), session.run(mvn_post.mean))

        # Expect the covariance to just contain the diagonal elements
        diag_cov = tf.matrix_diag(tf.matrix_diag_part(mvn_post.cov))
        assert np.allclose(session.run(fac_post.cov), session.run(diag_cov))
Example #13
 def sigma_tf(self, t, X, Y):  # M x 1, M x D, M x 1
     M = self.M
     D = self.D
     return tf.constant(self.snoise) * tf.matrix_diag(tf.ones([M, D]))
Example #14
def multihead_invertible_1x1_conv_np(name, x, x_mask, multihead_split, inverse,
                                     dtype):
    """Multi-head 1X1 convolution on x."""
    batch_size, length, n_channels_all = common_layers.shape_list(x)
    assert n_channels_all % 32 == 0
    n_channels = 32
    n_1x1_heads = n_channels_all // n_channels

    def get_init_np():
        """Initializer function for multihead 1x1 parameters using numpy."""
        results = []
        for _ in range(n_1x1_heads):
            random_matrix = np.random.rand(n_channels, n_channels)
            np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")
            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)[np.newaxis, :]
            np_log_s = np.log(np.abs(np_s))[np.newaxis, :]
            np_u = np.triu(np_u, k=1)
            results.append(
                np.concatenate([np_p, np_l, np_u, np_sign_s, np_log_s],
                               axis=0))
        return tf.convert_to_tensor(np.stack(results, axis=0))

    def get_mask_init():
        ones = tf.ones([n_1x1_heads, n_channels, n_channels], dtype=dtype)
        l_mask = tf.matrix_band_part(ones, -1, 0) - tf.matrix_band_part(
            ones, 0, 0)
        u_mask = tf.matrix_band_part(ones, 0, -1) - tf.matrix_band_part(
            ones, 0, 0)
        return tf.stack([l_mask, u_mask], axis=0)

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        params = tf.get_variable("params",
                                 initializer=get_init_np,
                                 dtype=dtype)
        mask_params = tf.get_variable("mask_params",
                                      initializer=get_mask_init,
                                      dtype=dtype,
                                      trainable=False)

        p = tf.stop_gradient(params[:, :n_channels, :])
        l = params[:, n_channels:2 * n_channels, :]
        u = params[:, 2 * n_channels:3 * n_channels, :]
        sign_s = tf.stop_gradient(params[:, 3 * n_channels, :])
        log_s = params[:, 3 * n_channels + 1, :]

        l_mask = mask_params[0]
        u_mask = mask_params[1]

        l_diag = l * l_mask + (tf.eye(
            n_channels, n_channels, [n_1x1_heads], dtype=dtype))
        u_diag = u * u_mask + (tf.matrix_diag(sign_s * tf.exp(log_s)))
        w = tf.matmul(p, tf.matmul(l_diag, u_diag))

        if multihead_split == "a":
            x = tf.reshape(x, [batch_size, length, n_channels, n_1x1_heads])
            x = tf.transpose(x, [3, 0, 1, 2])
        elif multihead_split == "c":
            x = tf.reshape(x, [batch_size, length, n_1x1_heads, n_channels])
            x = tf.transpose(x, [2, 0, 1, 3])
        else:
            raise ValueError("Multihead split not supported.")
        # [n_1x1_heads, batch_size, length, n_channels]

        if not inverse:
            # [n_1x1_heads, 1, n_channels, n_channels]
            x = tf.matmul(x, w[:, tf.newaxis, :, :])
        else:
            w_inv = tf.matrix_inverse(w)
            x = tf.matmul(x, w_inv[:, tf.newaxis, :, :])

        if multihead_split == "a":
            x = tf.transpose(x, [1, 2, 3, 0])
            x = tf.reshape(x, [batch_size, length, n_channels * n_1x1_heads])
        elif multihead_split == "c":
            x = tf.transpose(x, [1, 2, 0, 3])
            x = tf.reshape(x, [batch_size, length, n_1x1_heads * n_channels])
        else:
            raise ValueError("Multihead split not supported.")

        x_length = tf.reduce_sum(x_mask, -1)
        logabsdet = x_length * tf.reduce_sum(log_s)
        if inverse:
            logabsdet *= -1
    return x, logabsdet
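A hedged NumPy/SciPy sketch (illustration only) of the LU parameterisation above: w = P L (U_strict + diag(sign_s * exp(log_s))), whose log-determinant magnitude is just the sum of log_s, which is why logabsdet only needs log_s and the unmasked sequence length.

import numpy as np
import scipy.linalg

n = 4
w0 = scipy.linalg.qr(np.random.rand(n, n))[0]  # random orthogonal init, as in get_init_np
p, l, u = scipy.linalg.lu(w0)
s = np.diag(u)
sign_s, log_s = np.sign(s), np.log(np.abs(s))
u_strict = np.triu(u, k=1)

w = p @ l @ (u_strict + np.diag(sign_s * np.exp(log_s)))
assert np.allclose(w, w0)                                          # reconstruction is exact
assert np.isclose(np.log(np.abs(np.linalg.det(w))), log_s.sum())   # log|det w| = sum(log_s)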
Example #15
 def sigma_tf(self, t, X, Y):  # M x 1, M x D, M x 1
     M = self.M
     D = self.D
     return tf.matrix_diag(tf.ones([M, D]))  # M x D x D
Example #16
 def matrix_diag(x):
     return tf.matrix_diag(x)
Example #17
def mmr_fc(z_list, y_true, W, n_in, n_hs, n_out, n_rb, n_db, gamma_rb, gamma_db, bs, q):
    """
    Batch-wise implementation of the Maximum Margin Regularizer for fully-connected networks.
    Note that it is differentiable, and thus can be directly added to the main objective (e.g. the cross-entropy loss).

    z_list: list with all tensors that correspond to preactivation feature maps
            (in particular, z_list[-1] are logits; see models.MLP for details)
    y_true: one-hot encoded ground truth labels (bs x n_classes)
    W: list with all weight matrices
    n_in: total number of input pixels (e.g. 784 for MNIST)
    n_hs: list of number of hidden units for every hidden layer (e.g. [1024] for FC1)
    n_out: number of output classes
    n_rb: number of closest region boundaries to take
    n_db: number of closest decision boundaries to take
    gamma_rb: gamma for region boundaries (approx. corresponds to the radius of the Lp-ball that we want to
              certify robustness in)
    gamma_db: gamma for decision boundaries (approx. corresponds to the radius of the Lp-ball that we want to
              be robust in)
    bs: batch size
    q: q-norm which is the dual norm to the p-norm that we aim to be robust at (e.g. if p=np.inf, q=1)
    """
    eps_num_stabil = 1e-5  # epsilon for numerical stability in the denominator of the distances
    n_hl = len(W) - 1  # number of hidden layers
    y_pred = z_list[-1]

    relus = []
    for i in range(n_hl):
        relu = tf.cast(tf.greater(z_list[i], 0), tf.float32)
        relus.append(tf.expand_dims(relu, 1))

    dist_rb = tf.abs(z_list[0]) / tf.norm(W[0], axis=0, ord=q)  # bs x n_hs[0]  due to broadcasting
    V = tf.reshape(tf.tile(W[0], [bs, 1]), [bs, n_in, n_hs[0]])  # bs x d x n1
    V = V * relus[0]  # element-wise mult using broadcasting, result: bs x d x n_cur
    for i in range(1, n_hl):
        V = calc_v_fc(V, W[i])  # bs x d x n_hs[i]
        new_dist_rb = tf.abs(z_list[i]) / tf.norm(V, axis=1, ord=q)  # bs x n_hs[i]
        dist_rb = tf.concat([dist_rb, new_dist_rb], 1)  # bs x sum(n_hs[1:i])
        V = V * relus[i]  # element-wise mult using broadcasting, result: bs x d x n_cur

    th = zero_out_non_min_distances(dist_rb, n_rb)
    rb_term = tf.reduce_sum(th * tf.maximum(0.0, 1.0 - dist_rb / gamma_rb), axis=1)

    # decision boundaries
    V_last = calc_v_fc(V, W[-1])
    y_true_diag = tf.matrix_diag(y_true)
    LLK2 = V_last @ y_true_diag  # bs x d x K  @  bs x K x K  =  bs x d x K
    l = tf.reduce_sum(LLK2, axis=2)  # bs x d
    l = tf.tile(l, [1, n_out])  # bs x d*K
    l = tf.reshape(l, [-1, n_out, n_in])  # bs x K x d
    V_argmax = tf.transpose(l, perm=[0, 2, 1])  # bs x d x K
    diff_v = tf.abs(V_last - V_argmax)
    diff_v = diff_v + eps_num_stabil * tf.cast(tf.less(diff_v, eps_num_stabil), tf.float32)
    dist_db_denominator = tf.norm(diff_v, axis=1, ord=q)

    y_pred_diag = tf.expand_dims(y_pred, 1)
    y_pred_correct = y_pred_diag @ y_true_diag  # bs x 1 x K  @  bs x K x K  =  bs x 1 x K
    y_pred_correct = tf.reduce_sum(y_pred_correct, axis=2)  # bs x 1
    y_pred_correct = tf.tile(y_pred_correct, [1, n_out])  # bs x K
    y_pred_correct = tf.reshape(y_pred_correct, [-1, n_out, 1])  # bs x K x 1
    y_pred_correct = tf.transpose(y_pred_correct, perm=[0, 2, 1])  # bs x 1 x K
    dist_db_numerator = tf.squeeze(y_pred_correct - y_pred_diag, 1)  # bs x K
    dist_db_numerator = dist_db_numerator + 100.0 * y_true  # bs x K

    dist_db = dist_db_numerator / dist_db_denominator + y_true * 2.0 * gamma_db

    th = zero_out_non_min_distances(dist_db, n_db)
    db_term = tf.reduce_sum(th * tf.maximum(0.0, 1.0 - dist_db / gamma_db), axis=1)
    return rb_term, db_term
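zero_out_non_min_distances is not defined in this listing; judging from how it is used, it plausibly returns a 0/1 mask selecting, per example, the k smallest entries of dist, so that only the k closest boundaries enter the hinge sum. A hedged sketch under that assumption:

import tensorflow as tf

def zero_out_non_min_distances_sketch(dist, k):
    # tf.nn.top_k returns the largest values, so negate to find the k smallest.
    kth_smallest = -tf.nn.top_k(-dist, k=k).values[:, -1:]          # bs x 1
    return tf.cast(tf.less_equal(dist, kth_smallest), tf.float32)   # 1 for the k closest, else 0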
Example #18
def mmr_cnn(z_list, x, y_true, model, n_rb, n_db, gamma_rb, gamma_db, bs, q):
    """
    Batch-wise implementation of the Maximum Margin Regularizer for CNNs as a TensorFlow computational graph.
    Note that it is differentiable, and thus can be directly added to the main objective (e.g. the cross-entropy loss).

    z_list: list with all tensors that correspond to preactivation feature maps
            (in particular, z_list[-1] are logits; see models.LeNetSmall for details)
    x: input points (bs x image_height x image_width x image_n_channels)
    y_true: one-hot encoded ground truth labels (bs x n_classes)
    model: models.CNN object that contains a model with its weights, strides, padding, etc
    n_rb: number of closest region boundaries to take
    n_db: number of closest decision boundaries to take
    gamma_rb: gamma for region boundaries (approx. corresponds to the radius of the Lp-ball that we want to
              certify robustness in)
    gamma_db: gamma for decision boundaries (approx. corresponds to the radius of the Lp-ball that we want to
              be robust in)
    bs: batch size
    q: q-norm which is the dual norm to the p-norm that we aim to be robust at (e.g. if p=np.inf, q=1)
    """

    eps_num_stabil = 1e-5  # epsilon for numerical stability in the denominator of the distances

    # the padding and strides should be the same as in the forward pass conv
    strides, padding = model.strides, model.padding
    y_pred = z_list[-1]
    z_conv, z_fc, relus_conv, relus_fc, W_conv, W_fc = [], [], [], [], [], []
    for w, y in zip(model.W, z_list):  # Depending on the shape we form pre-activation values and their relu switches
        if len(y.shape) == 4:  # if conv layer
            z_conv.append(y)
            relu = tf.cast(tf.greater(y, 0), tf.float32)
            relus_conv.append(tf.expand_dims(relu, 1))
            W_conv.append(w)
        else:
            z_fc.append(y)
            relu = tf.cast(tf.greater(y, 0), tf.float32)
            relus_fc.append(tf.expand_dims(relu, 1))
            W_fc.append(w)

    h_in, w_in, c_in = int(x.shape[1]), int(x.shape[2]), int(x.shape[3])
    n_in = h_in * w_in * c_in
    n_out = y_true.shape[1]

    # z[0]: bs x h_next x w_next x n_next,  W[0]: h_filter x w_filter x n_prev x n_next
    w_matrix = tf.reshape(W_conv[0], [-1, int(W_conv[0].shape[-1])])  # h_filter*w_filter*n_col x n_next
    denom = tf.norm(w_matrix, axis=0, ord=q, keep_dims=True)  # n_next
    dist_rb = tf.abs(z_conv[0]) / denom  # bs x h_next x w_next x n_next
    dist_rb = tf.reshape(dist_rb, [bs, int(z_conv[0].shape[1]*z_conv[0].shape[2]*z_conv[0].shape[3])])  # bs x h_next*w_next*n_next

    # We need to get the conv matrix. Instead of using loops to construct such a matrix, we can apply the W[0]
    # conv filter to a reshaped identity matrix, then tile the resulting tensor bs times.
    identity_input_fm = tf.reshape(tf.eye(n_in, n_in), [1, n_in, h_in, w_in, c_in])
    V = calc_v_conv(identity_input_fm, W_conv[0], strides[0], padding)  # 1 x d x h_next x w_next x c_next
    V = tf.tile(V, [bs, 1, 1, 1, 1])  # bs x d x h_next x w_next x c_next
    V = V * relus_conv[0]
    for i in range(1, len(z_conv)):
        V = calc_v_conv(V, W_conv[i], strides[i], padding)  # bs x d x h_next x w_next x c_next
        V_stable = V + eps_num_stabil * tf.cast(tf.less(tf.abs(V), eps_num_stabil), tf.float32)  # note: +eps would also work
        new_dist_rb = tf.abs(z_conv[i]) / tf.norm(V_stable, axis=1, ord=q)  # bs x h_next x w_next x c_next
        new_dist_rb = tf.reshape(new_dist_rb, [bs, z_conv[i].shape[1]*z_conv[i].shape[2]*z_conv[i].shape[3]])  # bs x h_next*w_next*c_next
        dist_rb = tf.concat([dist_rb, new_dist_rb], 1)  # bs x sum(n_neurons[1:i])
        V = V * relus_conv[i]  # element-wise mult using broadcasting, result: bs x d x h_cur x w_cur x c_cur

    # Flattening after the last conv layer
    V = tf.reshape(V, [bs, n_in, V.shape[2] * V.shape[3] * V.shape[4]])  # bs x d x h_prev*w_prev*c_prev

    for i in range(len(z_fc) - 1):  # the last layer requires special handling
        V = calc_v_fc(V, W_fc[i])  # bs x d x n_hs[i]
        V_stable = V + eps_num_stabil * tf.cast(tf.less(tf.abs(V), eps_num_stabil), tf.float32)
        new_dist_rb = tf.abs(z_fc[i]) / tf.norm(V_stable, axis=1, ord=q)  # bs x n_hs[i]
        dist_rb = tf.concat([dist_rb, new_dist_rb], 1)  # bs x sum(n_hs[1:i])
        V = V * relus_fc[i]  # element-wise mult using broadcasting, result: bs x d x n_cur

    th = zero_out_non_min_distances(dist_rb, n_rb)
    rb_term = tf.reduce_sum(th * tf.maximum(0.0, 1.0 - dist_rb / gamma_rb), axis=1)

    # decision boundaries
    V = calc_v_fc(V, W_fc[-1])
    y_true_diag = tf.matrix_diag(y_true)
    LLK2 = V @ y_true_diag  # bs x d x K  @  bs x K x K  =  bs x d x K
    l = tf.reduce_sum(LLK2, axis=2)  # bs x d
    l = tf.tile(l, [1, n_out])  # bs x d*K
    l = tf.reshape(l, [-1, n_out, n_in])  # bs x K x d
    V_argmax = tf.transpose(l, perm=[0, 2, 1])  # bs x d x K
    diff_v = tf.abs(V - V_argmax)
    diff_v = diff_v + eps_num_stabil * tf.cast(tf.less(diff_v, eps_num_stabil), tf.float32)
    dist_db_denominator = tf.norm(diff_v, axis=1, ord=q)

    y_pred_diag = tf.expand_dims(y_pred, 1)
    y_pred_correct = y_pred_diag @ y_true_diag  # bs x 1 x K  @  bs x K x K  =  bs x 1 x K
    y_pred_correct = tf.reduce_sum(y_pred_correct, axis=2)  # bs x 1
    y_pred_correct = tf.tile(y_pred_correct, [1, n_out])  # bs x K
    y_pred_correct = tf.reshape(y_pred_correct, [-1, n_out, 1])  # bs x K x 1
    y_pred_correct = tf.transpose(y_pred_correct, perm=[0, 2, 1])  # bs x 1 x K
    dist_db_numerator = tf.squeeze(y_pred_correct - y_pred_diag, 1)  # bs x K
    dist_db_numerator = dist_db_numerator + 100.0 * y_true  # bs x K

    dist_db = dist_db_numerator / dist_db_denominator + y_true * 2.0 * gamma_db

    th = zero_out_non_min_distances(dist_db, n_db)
    db_term = tf.reduce_sum(th * tf.maximum(0.0, 1.0 - dist_db / gamma_db), axis=1)
    return rb_term, db_term
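In compact form, both regularizers above compute (my summary of the code, with V_j the effective linear map of unit j on the current activation region and q the dual norm):

$$d^{rb}_j = \frac{|z_j|}{\|V_j\|_q}, \qquad d^{db}_c = \frac{f_y(x) - f_c(x)}{\|V_y - V_c\|_q},$$

$$\text{rb\_term} = \sum_{j \in \mathcal{K}_{rb}} \max\Big(0,\, 1 - \frac{d^{rb}_j}{\gamma_{rb}}\Big), \qquad \text{db\_term} = \sum_{c \in \mathcal{K}_{db}} \max\Big(0,\, 1 - \frac{d^{db}_c}{\gamma_{db}}\Big),$$

where K_rb and K_db are the n_rb and n_db closest region and decision boundaries selected by zero_out_non_min_distances, and y is the true class.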