def test_grad_logistic(self):
    """Test if gradient has the correct value for logistic regression."""
    # Set data
    ndim = 2
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, -1.0, 1.0, -1.0])

    # Set model parameters for logistic regression.
    w = objax.TrainVar(jn.ones(ndim))

    def loss(x, y):
        # Per-example margin y_i * <x_i, w>; y[:, None] broadcasts each label across its feature row.
        xyw = jn.dot(x * y[:, None], w.value)
        # Logistic loss log(1 + exp(-margin)), averaged over the batch.
        return jn.log(jn.exp(-xyw) + 1).mean(0)

    grad = objax.Grad(loss, objax.VarCollection({'w': w}))
    g = grad(data, labels)
    self.assertEqual(g[0].shape, (ndim,))

    # Analytic gradient: d/dw mean(log(1 + exp(-y x.w))) = -mean(y x / (1 + exp(y x.w))).
    xw = np.dot(data, w.value)
    g_expect_w = -(data * (labels / (1 + np.exp(labels * xw)))[:, None]).mean(0)
    np.testing.assert_allclose(g[0], g_expect_w, atol=1e-7)
def test_on_conv_transpose_2d_padding(self):
    """Check a 2x2 transposed convolution with SAME padding.

    A single-channel 4x4 image is pushed through a one-filter ConvTranspose2D whose
    weights are set by hand. Both the output shape and every output value are compared
    against precomputed results.
    """
    # Positional arguments: input channels, number of filters, square kernel size.
    layer = objax.nn.ConvTranspose2D(1, 1, 2, padding=objax.ConvPadding.SAME)
    kernel = jn.array([[[[1., 2.], [3., 4.]]]])  # shape (1, 1, 2, 2) before the transpose below
    layer.w = objax.TrainVar(kernel.transpose((2, 3, 0, 1)))

    # Layout is NCHW: batch, channels, height, width.
    pixels = jn.array([[[[2., 1., 3., 4.],
                         [5., 6., 7., 8.],
                         [9., 10., 11., 12.],
                         [13., 14., 15., 16.]]]])
    output = layer(pixels)

    expected = jn.array([[[[2., 5., 5., 10.],
                           [11., 27., 32., 46.],
                           [24., 66., 76., 86.],
                           [40., 106., 116., 126.]]]])
    self.assertEqual(output.shape, (1, 1, 4, 4))
    self.assertTrue(jn.array_equal(output, expected))
def test_grad_linear_and_inputs(self):
    """Test if gradient of inputs and variables has the correct values for linear regression."""
    # Set data
    data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [-10.0, 9.0]])
    labels = np.array([1.0, 2.0, 3.0, 4.0])

    # Set model parameters for linear regression.
    w = objax.TrainVar(jn.array([2, 3], jn.float32))
    b = objax.TrainVar(jn.array([1], jn.float32))

    def loss(x, y):
        pred = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - pred) ** 2).mean()

    # Hand-derived gradients. With r = (pred - y) / 4 = [2, 4.25, 6.5, 1]:
    # dL/dx_i = r_i * w, dL/dy = -r, dL/dw = sum_i r_i * x_i, dL/db = sum_i r_i.
    expect_gx = [[4.0, 6.0], [8.5, 12.75], [13.0, 19.5], [2.0, 3.0]]
    expect_gy = [-2.0, -4.25, -6.5, -1.0]
    expect_gw = [37.25, 69.0]
    expect_gb = [13.75]
    expect_by_argnum = {0: expect_gx, 1: expect_gy}

    # Input gradients come first, in input_argnums order, followed by the
    # variable gradients in VarCollection order (w, then b).
    for input_argnums in [(0,), (1,), (0, 1), (1, 0)]:
        with self.subTest(input_argnums=input_argnums):
            grad = objax.Grad(loss, objax.VarCollection({'w': w, 'b': b}), input_argnums=input_argnums)
            g = grad(data, labels)
            expected = [expect_by_argnum[i] for i in input_argnums] + [expect_gw, expect_gb]
            self.assertEqual(len(g), len(expected))
            for got, want in zip(g, expected):
                self.assertEqual(got.tolist(), want)

    # Without variables only the requested input gradients are returned.
    grad_inputs_only = objax.Grad(loss, None, input_argnums=(0, 1))
    g = grad_inputs_only(data, labels)
    self.assertEqual(g[0].tolist(), expect_gx)
    self.assertEqual(g[1].tolist(), expect_gy)
def test_tensors(self):
    """Check that tensors() reports a variable registered under two names only once."""
    shared = objax.TrainVar(jn.ones(1))
    collection = objax.VarCollection([('a', objax.TrainVar(jn.zeros(1))), ('b', shared)])
    # Register the very same variable again under a second name.
    collection += objax.VarCollection([('c', shared)])
    tensors = collection.tensors()
    self.assertEqual(len(tensors), 2)
    self.assertEqual([t.sum() for t in tensors], [0, 1])
def test_assign(self):
    """Check that item assignment adds a new variable to a VarCollection."""
    collection = objax.VarCollection({'a': objax.TrainVar(jn.zeros(1))})
    collection['b'] = objax.TrainVar(jn.ones(1))
    self.assertEqual(len(collection), 2)
    # 'a' was built from zeros, 'b' from ones.
    for key, total in (('a', 0), ('b', 1)):
        self.assertEqual(collection[key].value.sum(), total)
def __init__(self, n):
    """Store a trainable variable of zeros with shape ``n`` as ``self.v``."""
    initial = jn.zeros(n)
    self.v = objax.TrainVar(initial)
def __init__(self, nin, nout):
    """Create a ``(nin, nout)`` matrix of ones plus a list of two zero variables.

    ``self.v1`` holds the ones matrix. ``self.var_list`` holds zero variables
    shaped ``nin`` and ``nout``, in that order.
    """
    self.v1 = objax.TrainVar(jn.ones((nin, nout)))
    self.var_list = [objax.TrainVar(jn.zeros(size)) for size in (nin, nout)]
def test_replicate_shape_assert(self):
    """Test that replicating a variable's value does not trip a shape assertion."""
    vc = objax.VarCollection({'var': objax.TrainVar(jn.zeros(5))})
    with vc.replicate():
        replicated_shape = vc['var'].value.shape
        # The test expects exactly one extra leading axis, with the original size 5 kept last.
        self.assertEqual(len(replicated_shape), 2)
        self.assertEqual(replicated_shape[-1], 5)
def test_jax_duck_typing_get_item(self):
    """Check that indexing a TrainVar directly matches indexing its ``.value``."""
    matrix = jn.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=jn.float32)
    var = objax.TrainVar(matrix)
    # (0, 1) is a scalar element; (1, slice(None)) is the same index as [1, :].
    for index in ((0, 1), (1, slice(None))):
        np.testing.assert_allclose(var.value[index], var[index])
def build(self):
    """Create the VGG19 layers as attributes and return them as an ordered list of ops.

    Attributes set: ``preprocess``, ``conv{stage}_{i}`` / ``relu{stage}_{i}`` and
    ``pool{stage}`` for five stages, then ``flatten``, ``fc6``, ``relu6``, ``fc7``,
    ``relu7`` and ``fc8``. When ``self.pretrained`` is true, the weights of conv and fc
    layers named in ``self.data_dict`` are replaced by the stored arrays.

    Layers are created in the same order as a fully unrolled definition, so random
    initialisation and attribute order are unaffected by the loop.
    """
    # inputs in [0, 255]
    self.preprocess = preprocess
    # NOTE(review): preprocess is stored but is not part of the returned ops;
    # confirm callers apply it themselves.

    # (output channels, number of 3x3 convolutions) for each of the five stages.
    stages = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))
    op_names = []
    channels_in = 3
    for stage, (channels_out, depth) in enumerate(stages, start=1):
        for layer in range(1, depth + 1):
            conv_name = f'conv{stage}_{layer}'
            relu_name = f'relu{stage}_{layer}'
            setattr(self, conv_name, objax.nn.Conv2D(nin=channels_in, nout=channels_out, k=3))
            setattr(self, relu_name, objax.functional.relu)
            op_names += [conv_name, relu_name]
            channels_in = channels_out
        pool_name = f'pool{stage}'
        setattr(self, pool_name, max_pool_2d)
        op_names.append(pool_name)

    self.flatten = objax.functional.flatten
    self.fc6 = objax.nn.Linear(nin=512 * 7 * 7, nout=4096)
    self.relu6 = objax.functional.relu
    self.fc7 = objax.nn.Linear(nin=4096, nout=4096)
    self.relu7 = objax.functional.relu
    self.fc8 = objax.nn.Linear(nin=4096, nout=1000)
    op_names += ['flatten', 'fc6', 'relu6', 'fc7', 'relu7', 'fc8']

    if self.pretrained:
        for name in self.data_dict:
            if name.startswith('conv'):
                conv = getattr(self, name)
                kernel, bias = self.data_dict[name]
                conv.w = objax.TrainVar(jn.array(kernel))
                # [:, None, None] appends two axes; presumably bias is 1-D (C,) and this
                # lets it broadcast over H and W of NCHW outputs — confirm.
                conv.b = objax.TrainVar(jn.array(bias[:, None, None]))
                setattr(self, name, conv)
            elif name.startswith('fc'):
                linear = getattr(self, name)
                kernel, bias = self.data_dict[name]
                if name == 'fc6':
                    # Presumably the stored rows are flattened in (H, W, C) order; reorder
                    # them to (C, H, W) to match flattening of channel-first features.
                    kernel = kernel.reshape([7, 7, 512, -1]).transpose((2, 0, 1, 3)).reshape([512 * 7 * 7, -1])
                linear.w = objax.TrainVar(jn.array(kernel))
                linear.b = objax.TrainVar(jn.array(bias))
                setattr(self, name, linear)

    return [getattr(self, name) for name in op_names]
def __init__(self, n_channels: int):
    """Hold a trainable square ``(n_channels, n_channels)`` matrix with orthogonal initialisation."""
    square_shape = (n_channels, n_channels)
    initial = objax.nn.init.orthogonal(square_shape)
    self.weight = objax.TrainVar(initial)
def __init__(self, n_channels: int, n_reflections: int = 10):
    """Create the trainable weight and the identity matrix used alongside it.

    Args:
        n_channels: width of each weight row and size of the identity matrix.
        n_reflections: number of weight rows (defaults to 10).
    """
    reflection_shape = (n_reflections, n_channels)
    initial = objax.nn.init.orthogonal(reflection_shape)
    self.weight = objax.TrainVar(initial)
    # Plain NumPy identity of size n_channels (not a TrainVar, so not trained).
    self.weight_init = np.eye(n_channels)
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Unitests for automatic variable tracing.""" import unittest import numpy as np import jax.numpy as jn import objax from objax.zoo.dnnet import DNNet global_w = objax.TrainVar(jn.zeros(5)) global_b = objax.TrainVar(jn.zeros(1)) global_m = objax.nn.Sequential([objax.nn.Conv2D(2, 4, 3), objax.nn.BatchNorm2D(4)]) class TestTracing(unittest.TestCase): """Unit tests for variable tracing using.""" def test_function_global_vars(self): def loss(x, y): pred = jn.dot(x, global_w.value) + global_b.value return 0.5 * ((y - pred) ** 2).mean() vc = objax.util.find_used_variables(loss) self.assertDictEqual(vc, {'global_w': global_w, 'global_b': global_b})