Example #1
    def test_grad_logistic(self):
        """Test if gradient has the correct value for logistic regression."""
        # Set data
        ndim = 2
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, -1.0, 1.0, -1.0])

        # Set model parameters for logistic regression.
        w = objax.TrainVar(jn.ones(ndim))

        def loss(x, y):
            xyw = jn.dot(x * np.tile(y, (ndim, 1)).transpose(), w.value)
            return jn.log(jn.exp(-xyw) + 1).mean(0)

        grad = objax.Grad(loss, objax.VarCollection({'w': w}))
        g = grad(data, labels)

        self.assertEqual(g[0].shape, (ndim,))

        xw = np.dot(data, w.value)
        g_expect_w = -(data * np.tile(labels / (1 + np.exp(labels * xw)),
                                      (ndim, 1)).transpose()).mean(0)
        np.testing.assert_allclose(g[0], g_expect_w, atol=1e-7)
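
As a cross-check (not part of the original test), the analytic gradient of the logistic loss can be evaluated directly with NumPy: for loss(w) = mean_i log(1 + exp(-y_i <x_i, w>)), the gradient is -mean_i x_i y_i / (1 + exp(y_i <x_i, w>)), which is exactly the g_expect_w formula above. A minimal sketch:

import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
y = np.array([1.0, -1.0, 1.0, -1.0])
w = np.ones(2)

margin = y * (x @ w)  # y_i * <x_i, w>
g = -(x * (y / (1 + np.exp(margin)))[:, None]).mean(0)
print(g)  # should match g_expect_w from the test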
Example #2
    def test_on_conv_transpose_2d_padding(self):
        """
        Pass an input through a two-by-two convolution filter with padding and
        test the shape and contents of the output.
        """

        # Arguments: input channels, output channels (filters), filter size (square)
        conv_filter = objax.nn.ConvTranspose2D(1,
                                               1,
                                               2,
                                               padding=objax.ConvPadding.SAME)
        weights = objax.TrainVar(
            jn.array([[[[1., 2.], [3., 4.]]]]).transpose((2, 3, 0, 1)))
        conv_filter.w = weights
        # NCHW layout: Batch, Channels/Colors, Height, Width
        image = jn.array([[[[2., 1., 3., 4.], [5., 6., 7., 8.],
                            [9., 10., 11., 12.], [13., 14., 15., 16.]]]])
        features = conv_filter(image)
        expected_features = jn.array([[[[2., 5., 5., 10.],
                                        [11., 27., 32., 46.],
                                        [24., 66., 76., 86.],
                                        [40., 106., 116., 126.]]]])
        self.assertEqual(features.shape, (1, 1, 4, 4))
        self.assertTrue(jn.array_equal(features, expected_features))
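
A stride-1 transposed convolution can be cross-checked by scatter-adding the scaled kernel at every input position and cropping the result. A minimal NumPy sketch, assuming SAME padding keeps the top-left 4x4 window of the full 5x5 output (which agrees with expected_features above):

import numpy as np

k = np.array([[1., 2.], [3., 4.]])
x = np.array([[2., 1., 3., 4.], [5., 6., 7., 8.],
              [9., 10., 11., 12.], [13., 14., 15., 16.]])
full = np.zeros((5, 5))  # a stride-1 transpose of a 2x2 kernel grows each dim by 1
for i in range(4):
    for j in range(4):
        full[i:i + 2, j:j + 2] += x[i, j] * k  # scatter the kernel scaled by one pixel
print(full[:4, :4])  # equals expected_features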
Example #3
    def test_grad_linear_and_inputs(self):
        """Test if gradient of inputs and variables has the correct values for linear regression."""
        # Set data
        data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        # Set model parameters for linear regression.
        w = objax.TrainVar(jn.array([2, 3], jn.float32))
        b = objax.TrainVar(jn.array([1], jn.float32))

        def loss(x, y):
            pred = jn.dot(x, w.value) + b.value
            return 0.5 * ((y - pred)**2).mean()

        expect_gw = [37.25, 69.0]
        expect_gb = [13.75]
        expect_gx = [[4.0, 6.0], [8.5, 12.75], [13.0, 19.5], [2.0, 3.0]]
        expect_gy = [-2.0, -4.25, -6.5, -1.0]

        grad0 = objax.Grad(loss,
                           objax.VarCollection({
                               'w': w,
                               'b': b
                           }),
                           input_argnums=(0, ))
        g = grad0(data, labels)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gw)
        self.assertEqual(g[2].tolist(), expect_gb)

        grad1 = objax.Grad(loss,
                           objax.VarCollection({
                               'w': w,
                               'b': b
                           }),
                           input_argnums=(1, ))
        g = grad1(data, labels)
        self.assertEqual(g[0].tolist(), expect_gy)
        self.assertEqual(g[1].tolist(), expect_gw)
        self.assertEqual(g[2].tolist(), expect_gb)

        grad01 = objax.Grad(loss,
                            objax.VarCollection({
                                'w': w,
                                'b': b
                            }),
                            input_argnums=(0, 1))
        g = grad01(data, labels)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gy)
        self.assertEqual(g[2].tolist(), expect_gw)
        self.assertEqual(g[3].tolist(), expect_gb)

        grad10 = objax.Grad(loss,
                            objax.VarCollection({
                                'w': w,
                                'b': b
                            }),
                            input_argnums=(1, 0))
        g = grad10(data, labels)
        self.assertEqual(g[0].tolist(), expect_gy)
        self.assertEqual(g[1].tolist(), expect_gx)
        self.assertEqual(g[2].tolist(), expect_gw)
        self.assertEqual(g[3].tolist(), expect_gb)

        # No VarCollection: gradients are computed for the inputs only.
        grad_inputs_only = objax.Grad(loss, None, input_argnums=(0, 1))
        g = grad_inputs_only(data, labels)
        self.assertEqual(g[0].tolist(), expect_gx)
        self.assertEqual(g[1].tolist(), expect_gy)
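
The expected constants follow from hand differentiation of the mean squared error. With residuals r_i = <x_i, w> + b - y_i and n = 4 samples, dL/dw = mean_i r_i x_i, dL/db = mean_i r_i, dL/dx_i = r_i w / n and dL/dy_i = -r_i / n. A minimal NumPy sketch reproducing the values above:

import numpy as np

x = np.array([[1., 2.], [3., 4.], [5., 6.], [-10., 9.]])
y = np.array([1., 2., 3., 4.])
w, b = np.array([2., 3.]), 1.0

r = x @ w + b - y                # residuals: [8, 17, 26, 4]
print((r[:, None] * x).mean(0))  # expect_gw -> [37.25, 69.0]
print(r.mean())                  # expect_gb -> 13.75
print(r[:, None] * w / len(y))   # expect_gx, one row per sample
print(-r / len(y))               # expect_gy -> [-2.0, -4.25, -6.5, -1.0]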
Example #4
    def test_tensors(self):
        vshared = objax.TrainVar(jn.ones(1))
        vc = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1))), ('b', vshared)])
        vc += objax.VarCollection([('c', vshared)])
        self.assertEqual(len(vc.tensors()), 2)
        self.assertEqual([x.sum() for x in vc.tensors()], [0, 1])
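
The point of this test is that the names 'b' and 'c' alias a single TrainVar, so tensors() deduplicates: the collection holds three names but returns only two distinct tensors, which prevents shared weights from being updated twice in an optimizer step. A minimal sketch of the same aliasing:

vshared = objax.TrainVar(jn.ones(1))
vc = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1))),
                          ('b', vshared), ('c', vshared)])
assert len(vc) == 3            # three names ...
assert len(vc.tensors()) == 2  # ... but only two distinct variables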
Example #5
    def test_assign(self):
        vc = objax.VarCollection({'a': objax.TrainVar(jn.zeros(1))})
        vc['b'] = objax.TrainVar(jn.ones(1))
        self.assertEqual(len(vc), 2)
        self.assertEqual(vc['a'].value.sum(), 0)
        self.assertEqual(vc['b'].value.sum(), 1)
Example #6
    def __init__(self, n):
        self.v = objax.TrainVar(jn.zeros(n))
Example #7
    def __init__(self, nin, nout):
        self.v1 = objax.TrainVar(jn.ones((nin, nout)))
        self.var_list = [
            objax.TrainVar(jn.zeros(nin)),
            objax.TrainVar(jn.zeros(nout))
        ]
Example #8
    def test_replicate_shape_assert(self):
        """Test that replicated variables have the expected shapes."""
        vc = objax.VarCollection({'var': objax.TrainVar(jn.zeros(5))})
        with vc.replicate():
            self.assertEqual(len(vc['var'].value.shape), 2)
            self.assertEqual(vc['var'].value.shape[-1], 5)
Example #9
    def test_jax_duck_typing_get_item(self):
        v = objax.TrainVar(
            jn.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
                     dtype=jn.float32))
        np.testing.assert_allclose(v.value[0, 1], v[0, 1])
        np.testing.assert_allclose(v.value[1, :], v[1, :])
Example #10
    def build(self):
        # inputs in [0, 255]
        self.preprocess = preprocess
        self.conv1_1 = objax.nn.Conv2D(nin=3, nout=64, k=3)
        self.relu1_1 = objax.functional.relu
        self.conv1_2 = objax.nn.Conv2D(nin=64, nout=64, k=3)
        self.relu1_2 = objax.functional.relu
        self.pool1 = max_pool_2d

        self.conv2_1 = objax.nn.Conv2D(nin=64, nout=128, k=3)
        self.relu2_1 = objax.functional.relu
        self.conv2_2 = objax.nn.Conv2D(nin=128, nout=128, k=3)
        self.relu2_2 = objax.functional.relu
        self.pool2 = max_pool_2d

        self.conv3_1 = objax.nn.Conv2D(nin=128, nout=256, k=3)
        self.relu3_1 = objax.functional.relu
        self.conv3_2 = objax.nn.Conv2D(nin=256, nout=256, k=3)
        self.relu3_2 = objax.functional.relu
        self.conv3_3 = objax.nn.Conv2D(nin=256, nout=256, k=3)
        self.relu3_3 = objax.functional.relu
        self.conv3_4 = objax.nn.Conv2D(nin=256, nout=256, k=3)
        self.relu3_4 = objax.functional.relu
        self.pool3 = max_pool_2d

        self.conv4_1 = objax.nn.Conv2D(nin=256, nout=512, k=3)
        self.relu4_1 = objax.functional.relu
        self.conv4_2 = objax.nn.Conv2D(nin=512, nout=512, k=3)
        self.relu4_2 = objax.functional.relu
        self.conv4_3 = objax.nn.Conv2D(nin=512, nout=512, k=3)
        self.relu4_3 = objax.functional.relu
        self.conv4_4 = objax.nn.Conv2D(nin=512, nout=512, k=3)
        self.relu4_4 = objax.functional.relu
        self.pool4 = max_pool_2d

        self.conv5_1 = objax.nn.Conv2D(nin=512, nout=512, k=3)
        self.relu5_1 = objax.functional.relu
        self.conv5_2 = objax.nn.Conv2D(nin=512, nout=512, k=3)
        self.relu5_2 = objax.functional.relu
        self.conv5_3 = objax.nn.Conv2D(nin=512, nout=512, k=3)
        self.relu5_3 = objax.functional.relu
        self.conv5_4 = objax.nn.Conv2D(nin=512, nout=512, k=3)
        self.relu5_4 = objax.functional.relu
        self.pool5 = max_pool_2d

        self.flatten = objax.functional.flatten
        self.fc6 = objax.nn.Linear(nin=512 * 7 * 7, nout=4096)
        self.relu6 = objax.functional.relu
        self.fc7 = objax.nn.Linear(nin=4096, nout=4096)
        self.relu7 = objax.functional.relu
        self.fc8 = objax.nn.Linear(nin=4096, nout=1000)

        if self.pretrained:
            for it in self.data_dict:
                if it.startswith('conv'):
                    conv = getattr(self, it)
                    kernel, bias = self.data_dict[it]
                    conv.w = objax.TrainVar(jn.array(kernel))
                    conv.b = objax.TrainVar(jn.array(bias[:, None, None]))
                    setattr(self, it, conv)
                elif it.startswith('fc'):
                    linear = getattr(self, it)
                    kernel, bias = self.data_dict[it]
                    if it == 'fc6':
                        kernel = kernel.reshape([7, 7, 512, -1]).transpose(
                            (2, 0, 1, 3)).reshape([512 * 7 * 7, -1])
                    linear.w = objax.TrainVar(jn.array(kernel))
                    linear.b = objax.TrainVar(jn.array(bias))
                    setattr(self, it, linear)

        ops = [
            self.conv1_1, self.relu1_1, self.conv1_2, self.relu1_2, self.pool1,
            self.conv2_1, self.relu2_1, self.conv2_2, self.relu2_2, self.pool2,
            self.conv3_1, self.relu3_1, self.conv3_2, self.relu3_2,
            self.conv3_3, self.relu3_3, self.conv3_4, self.relu3_4, self.pool3,
            self.conv4_1, self.relu4_1, self.conv4_2, self.relu4_2,
            self.conv4_3, self.relu4_3, self.conv4_4, self.relu4_4, self.pool4,
            self.conv5_1, self.relu5_1, self.conv5_2, self.relu5_2,
            self.conv5_3, self.relu5_3, self.conv5_4, self.relu5_4, self.pool5,
            self.flatten, self.fc6, self.relu6, self.fc7, self.relu7, self.fc8
        ]

        return ops
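
The fc6 branch above re-indexes a pretrained kernel stored in H, W, C flatten order (the TF/NHWC convention) into the C, H, W order that objax.functional.flatten produces for NCHW activations. A hypothetical toy check of that reindexing, with made-up sizes h, w, c:

import numpy as np

h, w, c, nout = 2, 2, 3, 1
k_hwc = np.arange(h * w * c * nout, dtype=np.float32).reshape(h * w * c, nout)
k_chw = k_hwc.reshape(h, w, c, nout).transpose(2, 0, 1, 3).reshape(c * h * w, nout)
x = np.random.rand(h, w, c).astype(np.float32)  # one activation map in HWC order
# HWC-flattened features against the original kernel give the same output as
# CHW-flattened features against the permuted kernel.
assert np.allclose(x.reshape(-1) @ k_hwc, x.transpose(2, 0, 1).reshape(-1) @ k_chw)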
Example #11
    def __init__(self, n_channels: int):
        self.weight = objax.TrainVar(
            objax.nn.init.orthogonal((n_channels, n_channels)))
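
A quick numerical sanity check (an assumption about the initializer's guarantee, not part of the original snippet): for a square shape, objax.nn.init.orthogonal should return a matrix W with W W^T close to the identity:

import numpy as np
w = objax.nn.init.orthogonal((4, 4))
np.testing.assert_allclose(np.array(w @ w.T), np.eye(4), atol=1e-5)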
Example #12
    def __init__(self, n_channels: int, n_reflections: int = 10):
        self.weight = objax.TrainVar(
            objax.nn.init.orthogonal((n_reflections, n_channels)))
        self.weight_init = np.eye(n_channels)
Example #13
"""Unitests for automatic variable tracing."""

import unittest

import numpy as np
import jax.numpy as jn
import objax
from objax.zoo.dnnet import DNNet


global_w = objax.TrainVar(jn.zeros(5))
global_b = objax.TrainVar(jn.zeros(1))

global_m = objax.nn.Sequential([objax.nn.Conv2D(2, 4, 3), objax.nn.BatchNorm2D(4)])


class TestTracing(unittest.TestCase):
    """Unit tests for variable tracing using."""

    def test_function_global_vars(self):
        def loss(x, y):
            pred = jn.dot(x, global_w.value) + global_b.value
            return 0.5 * ((y - pred) ** 2).mean()

        vc = objax.util.find_used_variables(loss)
        self.assertDictEqual(vc, {'global_w': global_w, 'global_b': global_b})
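
find_used_variables should equally trace variables reached through a module rather than bare TrainVars; a hedged sketch under that assumption, reusing global_m from above (conv_loss is an illustrative name, and passing training=True assumes Sequential forwards it to BatchNorm2D):

def conv_loss(x):
    return global_m(x, training=True).mean()

vc = objax.util.find_used_variables(conv_loss)
# vc is expected to contain the Conv2D and BatchNorm2D variables of global_m.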