예제 #1
0
    def __init__(self, methodname):
        """Build the fixture shared by every test method of this class.

        Creates a fixed random dataset (inputs in [0, 10) and one-hot float32
        targets), a ``DNNet`` with layer sizes ``[ndim, nclass]`` and softmax
        activation, and a mean-squared-error loss over that model.

        Args:
            methodname: name of the test method this instance will run;
                forwarded unchanged to ``super().__init__``.
        """
        super().__init__(methodname)

        self.ntrain, self.nclass, self.ndim = 100, 5, 5

        # Fixed seed so the generated data is the same for every instance.
        np.random.seed(1234)
        self.data = 10 * np.random.rand(self.ntrain, self.ndim)
        class_ids = np.random.randint(self.nclass, size=self.ntrain)
        # Row k of the float32 identity matrix is the one-hot code of class k.
        self.labels = np.eye(self.nclass, dtype='f')[class_ids]

        # Model and loss (no optimizer is created here).
        self.model = DNNet(layer_sizes=[self.ndim, self.nclass],
                           activation=objax.functional.softmax)
        self.model_vars = self.model.vars()

        def loss_function(x, y):
            # Squared error per element, averaged over classes then over the batch.
            squared_error = (y - self.model(x)) ** 2
            loss = squared_error.mean(1).mean(0)
            return loss, {'loss': loss}

        self.loss = loss_function
예제 #2
0
    def test_typical_training_loop(self):
        """Check ``find_used_variables`` on the pieces of a typical training setup.

        The prediction op and the loss close over ``model`` and are expected to
        report the model's variables under the ``'model.'`` scope.  The training
        op closes over ``loss_gv`` and ``opt`` and is expected to report the
        model's variables under ``'loss_gv.model.'`` plus the optimizer's own
        variables under ``'opt.'``.
        """
        model = DNNet((32, 10), objax.functional.leaky_relu)
        opt = objax.optimizer.Momentum(model.vars(), nesterov=True)
        expected_model_vars = model.vars(scope='model.')

        # Inference: softmax over the model output in eval mode.
        def predict_op(x):
            return objax.functional.softmax(model(x, training=False))

        self.assertDictEqual(objax.util.find_used_variables(predict_op),
                             expected_model_vars)

        # Training loss: mean sparse cross-entropy on integer labels.
        def loss(x, label):
            logit = model(x, training=True)
            return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()

        self.assertDictEqual(objax.util.find_used_variables(loss),
                             expected_model_vars)

        # Gradient of the loss w.r.t. exactly the variables the loss uses.
        loss_gv = objax.GradValues(loss, objax.util.find_used_variables(loss))

        def train_op(x, y, learning_rate):
            grads, values = loss_gv(x, y)
            opt(learning_rate, grads)
            return values

        expected_train_vars = {**model.vars(scope='loss_gv.model.'),
                               **opt.vars(scope='opt.')}
        self.assertDictEqual(objax.util.find_used_variables(train_op),
                             expected_train_vars)
예제 #3
0
class TestPrivateGradValues(unittest.TestCase):
    """Tests for ``objax.privacy.dpsgd.PrivateGradValues``.

    A ``DNNet`` with layer sizes ``[ndim, nclass]`` is evaluated on a fixed
    random dataset with one-hot targets and a mean-squared-error loss.  The
    tests check that the private gradient:

    * matches the non-private ``objax.GradValues`` gradient when clipping and
      noise are effectively disabled,
    * has an L2 norm no larger than ``l2_norm_clip`` when noise is disabled,
    * carries noise whose per-coordinate standard deviation agrees with the
      expected value under a chi-square test.
    """

    def __init__(self, methodname='runTest'):
        """Initialize the test class.

        Args:
            methodname: name of the test method to run; defaults to
                ``'runTest'`` like ``unittest.TestCase``.
        """
        super().__init__(methodname)

        self.ntrain = 100
        self.nclass = 5
        self.ndim = 5

        # Generate random data.  A local RandomState seeded with 1234 draws
        # exactly the same numbers as ``np.random.seed(1234)`` followed by the
        # global ``np.random`` calls, without clobbering the process-wide RNG
        # (unittest builds every TestCase instance at load time, so a global
        # seed here would not make later tests reproducible anyway).
        rng = np.random.RandomState(1234)
        self.data = rng.rand(self.ntrain, self.ndim) * 10
        label_ids = rng.randint(self.nclass, size=self.ntrain)
        # One-hot float32 targets of shape (ntrain, nclass).
        self.labels = (np.arange(self.nclass) == label_ids[:, None]).astype('f')

        # Set model and loss (no optimizer is needed: only gradients are tested).
        self.model = DNNet(layer_sizes=[self.ndim, self.nclass],
                           activation=objax.functional.softmax)
        self.model_vars = self.model.vars()

        def loss_function(x, y):
            logit = self.model(x)
            loss = ((y - logit)**2).mean(1).mean(0)
            return loss, {'loss': loss}

        self.loss = loss_function

    def test_private_gradvalues_compare_nonpriv(self):
        """Test if PrivateGradValues without clipping / noise is the same as non-private GradValues."""
        # A clip bound far above any gradient norm plus zero noise should make
        # the private gradient match the non-private one.  ntrain is divisible
        # by every microbatch size used below.
        l2_norm_clip = 1e10
        noise_multiplier = 0

        gv = objax.GradValues(self.loss, self.model_vars)
        g, _ = gv(self.data, self.labels)
        expected_loss = self.loss(self.data, self.labels)[0]

        for use_norm_accumulation in [True, False]:
            for microbatch in [1, 10, self.ntrain]:
                with self.subTest(use_norm_accumulation=use_norm_accumulation,
                                  microbatch=microbatch):
                    gv_priv = objax.Jit(
                        objax.privacy.dpsgd.PrivateGradValues(
                            self.loss,
                            self.model_vars,
                            noise_multiplier,
                            l2_norm_clip,
                            microbatch,
                            batch_axis=(0, 0),
                            use_norm_accumulation=use_norm_accumulation))
                    g_priv, v_priv = gv_priv(self.data, self.labels)

                    # Check the shape of the gradient (bias, then weight).
                    self.assertEqual(g_priv[0].shape, (self.nclass,))
                    self.assertEqual(g_priv[1].shape, (self.ndim, self.nclass))

                    # Check if the private gradient is similar to the non-private gradient.
                    np.testing.assert_allclose(g[0], g_priv[0], atol=1e-7)
                    np.testing.assert_allclose(g[1], g_priv[1], atol=1e-7)
                    np.testing.assert_allclose(v_priv[0], expected_loss, atol=1e-7)

    def test_private_gradvalues_clipping(self):
        """Test if the gradient norm is within l2_norm_clip."""
        noise_multiplier = 0
        # Slack for floating-point rounding of the squared norm.  The absolute
        # term covers l2_norm_clip == 0; the relative term covers larger clip
        # values, where a purely absolute 1e-8 slack is tighter than float32
        # resolution (eps ~1.2e-7 near 1.0) for an exactly-clipped gradient.
        # NOTE(review): assumes gradients are float32 (JAX's default precision).
        acceptable_float_error = 1e-8
        acceptable_relative_error = 1e-5
        expected_loss = self.loss(self.data, self.labels)[0]

        for use_norm_accumulation in [True, False]:
            for microbatch in [1, 10, self.ntrain]:
                for l2_norm_clip in [0, 1e-2, 1e-1, 1.0]:
                    with self.subTest(use_norm_accumulation=use_norm_accumulation,
                                      microbatch=microbatch,
                                      l2_norm_clip=l2_norm_clip):
                        gv_priv = objax.Jit(
                            objax.privacy.dpsgd.PrivateGradValues(
                                self.loss,
                                self.model_vars,
                                noise_multiplier,
                                l2_norm_clip,
                                microbatch,
                                batch_axis=(0, 0),
                                use_norm_accumulation=use_norm_accumulation))
                        g_priv, v_priv = gv_priv(self.data, self.labels)
                        # Get the actual squared norm of the gradient.
                        g_normsquared = sum(np.sum(g**2) for g in g_priv)
                        bound = (l2_norm_clip**2 * (1 + acceptable_relative_error)
                                 + acceptable_float_error)
                        self.assertLessEqual(g_normsquared, bound)
                        np.testing.assert_allclose(v_priv[0], expected_loss, atol=1e-7)

    def test_private_gradvalues_noise(self):
        """Test if the noise std is around expected."""
        runs = 100
        alpha = 0.0001
        expected_loss = self.loss(self.data, self.labels)[0]

        def chi2_cdf(std_empirical, std_expected):
            # CDF of the statistic (runs - 1) * s^2 / sigma^2, which follows
            # chi2(runs - 1) when the true std equals std_expected.
            return chi2.cdf((runs - 1) * std_empirical**2 / std_expected**2,
                            runs - 1)

        def accepted(cdf):
            # Two-sided test with probability alpha in each tail, required to
            # pass for every gradient coordinate.
            return np.all(alpha <= cdf) and np.all(cdf <= 1.0 - alpha)

        configs = [(use_norm_accumulation, microbatch, noise_multiplier, l2_norm_clip)
                   for use_norm_accumulation in [True, False]
                   for microbatch in [1, 10, self.ntrain]
                   for noise_multiplier in [0.1, 10.0]
                   for l2_norm_clip in [0.01, 0.1]]

        for use_norm_accumulation, microbatch, noise_multiplier, l2_norm_clip in configs:
            with self.subTest(use_norm_accumulation=use_norm_accumulation,
                              microbatch=microbatch,
                              noise_multiplier=noise_multiplier,
                              l2_norm_clip=l2_norm_clip):
                gv_priv = objax.Jit(
                    objax.privacy.dpsgd.PrivateGradValues(
                        self.loss,
                        self.model_vars,
                        noise_multiplier,
                        l2_norm_clip,
                        microbatch,
                        batch_axis=(0, 0),
                        use_norm_accumulation=use_norm_accumulation))
                # Repeat the run and collect all gradients, flattened into one
                # vector per run.
                g_privs = []
                for _ in range(runs):
                    g_priv, v_priv = gv_priv(self.data, self.labels)
                    g_privs.append(np.concatenate([g_n.reshape(-1) for g_n in g_priv]))
                    np.testing.assert_allclose(v_priv[0], expected_loss, atol=1e-7)
                g_privs = np.array(g_privs)

                # Compute empirical std and expected std.  The expected value
                # presumes noise of std l2_norm_clip * noise_multiplier on the
                # summed gradient, averaged over ntrain // microbatch microbatches.
                std_empirical = np.std(g_privs, axis=0, ddof=1)
                std_theoretical = l2_norm_clip * noise_multiplier / (
                    self.ntrain // microbatch)

                # Chi-square test for the correct expected standard deviation.
                self.assertTrue(accepted(chi2_cdf(std_empirical, std_theoretical)))

                # Chi-square tests for incorrect expected standard deviations:
                # expect rejection.
                self.assertFalse(accepted(chi2_cdf(std_empirical, 1.25 * std_theoretical)))
                self.assertFalse(accepted(chi2_cdf(std_empirical, 0.75 * std_theoretical)))
예제 #4
0
# Test images: scale pixel values by 1/127.5 and shift by -1 (maps [0, 255],
# presumably uint8, to [-1, 1]), move axis 3 to position 1 (presumably
# NHWC -> NCHW), then flatten every image into a row of length image_size.
flat_test_images = np.reshape(
    data['test']['image'].transpose(0, 3, 1, 2) / 127.5 - 1,
    newshape=(test_size, image_size),
)
train = EasyDict(image=flat_train_images, label=data['train']['label'])
test = EasyDict(image=flat_test_images, label=data['test']['label'])
# Drop the reference to the raw dataset so it can be freed.
del data

# Settings
lr = 2e-4
batch = 64
num_train_epochs = 40
dnn_layer_sizes = (image_size, 128, 10)
logdir = f'experiments/classify/img/mnist/filters{dnn_layer_sizes}'

# Model, an EMA copy of it (momentum 0.999), and an Adam optimizer over its variables.
model = DNNet(dnn_layer_sizes, leaky_relu)
model_ema = objax.optimizer.ExponentialMovingAverageModule(
    model,
    momentum=0.999,
)
opt = objax.optimizer.Adam(model.vars())


@objax.Function.with_vars(model.vars())
def loss(x, label):
    """Return the batch mean of ``cross_entropy_logits(model(x), label)``.

    NOTE(review): ``cross_entropy_logits`` takes dense label probabilities, so
    ``label`` is presumably one-hot (same shape as the logits) — confirm
    against the training loop that calls this.
    """
    xe = objax.functional.loss.cross_entropy_logits(model(x), label)
    return xe.mean()


# Wrap ``loss`` so that calling ``gv(x, label)`` returns both the gradients with
# respect to ``model.vars()`` and the loss output.
gv = objax.GradValues(loss, model.vars())


@objax.Function.with_vars(model.vars() + gv.vars() + opt.vars() +