def __init__(self,
                 num_models,
                 dense_kernel_size=32,
                 embedding_dim=32,
                 seed=0,
                 logit_temp=1.0,
                 orthogonal_init=True):
        """Create stacked per-model conv, dense and embedding parameters.

        Every kernel and bias is allocated with a leading ``num_models`` axis,
        i.e. one parameter slice per model. Biases are zero-initialised.

        Args:
            num_models: number of models; must be at least 2.
            dense_kernel_size: output channels of the 1x1 dense kernel.
            embedding_dim: output channels of the 1x1 embedding kernel.
            seed: seed of the single PRNG key all kernel initialisers derive
                their subkeys from.
            logit_temp: stored as ``self.logit_temp``; not used here.
            orthogonal_init: if True every kernel uses ``orthogonal``;
                otherwise conv/dense kernels use ``he_normal`` and the
                embedding kernel uses ``glorot_normal``.

        Raises:
            ValueError: if ``num_models`` is less than 2 (a subclass of
                ``Exception``, so existing ``except Exception`` handlers still
                catch it).
        """
        if num_models <= 1:
            raise ValueError("requires at least two models")

        self.num_models = num_models
        self.logit_temp = logit_temp

        # Output channels of the 3x3 conv stack (input has 3 channels).
        conv_channels = (32, 64, 64, 64, 64, 64)
        num_conv = len(conv_channels)

        # One subkey per conv layer, then one for the dense kernel and one for
        # the embedding kernel (8 keys in total for the default stack).
        subkeys = random.split(random.PRNGKey(seed), num_conv + 2)

        # conv stack kernels and biases
        hidden_init = orthogonal if orthogonal_init else he_normal
        self.conv_kernels = objax.ModuleList()
        self.conv_biases = objax.ModuleList()
        input_channels = 3
        for i, output_channels in enumerate(conv_channels):
            self.conv_kernels.append(
                TrainVar(hidden_init()(
                    subkeys[i],
                    (num_models, 3, 3, input_channels, output_channels))))
            self.conv_biases.append(
                TrainVar(jnp.zeros((num_models, output_channels))))
            input_channels = output_channels

        # Dense (1x1) kernels and biases; after the loop ``input_channels``
        # holds the last conv layer's channel count.
        self.dense_kernels = TrainVar(hidden_init()(
            subkeys[num_conv],
            (num_models, 1, 1, input_channels, dense_kernel_size)))
        self.dense_biases = TrainVar(jnp.zeros(
            (num_models, dense_kernel_size)))

        # Embeddings kernel; no bias or non linearity.
        embedding_init = orthogonal if orthogonal_init else glorot_normal
        self.embedding_kernels = TrainVar(embedding_init()(
            subkeys[num_conv + 1],
            (num_models, 1, 1, dense_kernel_size, embedding_dim)))
Пример #2
0
    def test_gradvalues_constant(self):
        """Check that GradValues sees updates to a variable it does not differentiate.

        Only ``w`` is differentiated, so ``b`` acts as a constant in the loss.
        Flipping the sign of ``b`` must still change the gradient with respect
        to ``w``, both eagerly and when wrapped in ``objax.Jit``.
        """
        # Linear-regression data: four 1-d samples with targets 1..4.
        ndim = 1
        data = np.array([[1.0], [3.0], [5.0], [-10.0]])
        labels = np.array([1.0, 2.0, 3.0, 4.0])

        weight = objax.TrainVar(jn.zeros(ndim))
        bias = objax.TrainVar(jn.ones(1))
        params = objax.ModuleList([weight, bias])

        def loss(x, y):
            residual = y - (jn.dot(x, weight.value) + bias.value)
            return 0.5 * (residual ** 2).mean()

        def check_bias_flip_changes_grad(grad_values):
            grads_before, _ = grad_values(data, labels)
            bias.assign(-bias.value)
            grads_after, _ = grad_values(data, labels)
            self.assertNotEqual(grads_before[0][0], grads_after[0][0])

        # Eager evaluation.
        check_bias_flip_changes_grad(
            objax.GradValues(loss, objax.VarCollection({'w': weight})))

        # Jit-compiled; the Jit is given params.vars(), which includes b.
        check_bias_flip_changes_grad(
            objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': weight})),
                      params.vars()))
Пример #3
0
    def test_ema(self):
        """Check debiased EMA values after one update, for momenta 0.1 .. 0.9.

        After a single ``ema()`` call, functions wrapped with ``replace_vars``
        must see the debiased average, while the original variable keeps its
        initial value.
        """
        eps = 1e-6
        orig_value_expect = np.array([100.0, -1.0])
        for step in range(1, 10):
            m = step / 10.0
            x = objax.ModuleList([objax.TrainVar(jn.array(orig_value_expect))])
            ema = objax.optimizer.ExponentialMovingAverage(
                x.vars(), momentum=m, debias=True, eps=eps)
            ema()
            # Expected debiased average after exactly one update.
            ema_value_expect = (
                (1 - m) * orig_value_expect / (1 - (1 - eps) * m))

            def get_tensors():
                return x.vars().tensors()

            ema_value = ema.replace_vars(get_tensors)()
            orig_value = get_tensors()

            for actual, expected in ((ema_value[0], ema_value_expect),
                                     (orig_value[0], orig_value_expect)):
                np.testing.assert_allclose(actual, expected, rtol=1e-6)
Пример #4
0
 def test_weight_sharing(self):
     """Check that one variable reachable under two names is stored only once."""
     shared = objax.TrainVar(jn.ones(3))
     m = objax.ModuleList([shared])
     all_vars = m.vars('m1') + m.vars('m2')
     # Two names map to the same variable: len() counts names, while
     # iteration and tensors() yield the single underlying variable.
     self.assertEqual(len(all_vars), 2)
     self.assertEqual(len(list(all_vars)), 1)
     self.assertEqual(len(all_vars.tensors()), 1)
     # Incrementing through the merged collection updates ones(3) once -> sum 6.
     all_vars.assign([t + 1 for t in all_vars.tensors()])
     self.assertEqual(shared.value.sum(), 6)
Пример #5
0
 def test_module_list(self):
     """Check that ModuleList.vars() yields each child's ``v1`` in list order."""
     first, second = SimpleModule(3, 5), SimpleModule(5, 7)
     collected = list(objax.ModuleList([first, second]).vars())
     self.assertEqual(len(collected), 2)
     for var, module in zip(collected, (first, second)):
         self.assertIs(var, module.v1)
Пример #6
0
    def test_trainvar_assign(self):
        """Check that an assign inside a Jit-compiled function persists."""
        counter = objax.TrainVar(jn.zeros(2))
        m = objax.ModuleList([counter])

        def increase():
            counter.assign(counter.value + 1)
            return counter.value

        objax.Jit(increase, m.vars())()
        self.assertEqual(counter.value.tolist(), [1., 1.])
Пример #7
0
    def test_trainvar_assign_multivalue(self):
        """Check the TrainVar value after per-input assigns inside Parallel."""
        counter = objax.TrainVar(jn.array((1., 2.)))
        m = objax.ModuleList([counter])

        def increase(x):
            counter.assign(counter.value + x)
            return counter.value

        para_increase = objax.Parallel(increase, m.vars())
        with m.vars().replicate():
            para_increase(jn.arange(8))
        # Expected values are the initial (1, 2) shifted by 3.5 == mean(arange(8)).
        self.assertEqual(counter.value.tolist(), [4.5, 5.5])
Пример #8
0
    def test_trainvar_assign(self):
        """Check that an assign inside Parallel is visible after replicate() exits."""
        counter = objax.TrainVar(jn.zeros(2))
        m = objax.ModuleList([counter])

        def increase():
            counter.assign(counter.value + 1)
            return counter.value

        para_increase = objax.Parallel(increase, m.vars())
        with m.vars().replicate():
            para_increase()
        self.assertEqual(counter.value.tolist(), [1., 1.])
Пример #9
0
 def test_module_wrapper(self):
     """Check that ModuleWrapper re-exposes every variable under a class-name prefix."""
     var_collection = objax.ModuleList(
         [SampleModule(3, 5), SampleModule(5, 7)]).vars()
     module_wrapper = objax.ModuleWrapper(var_collection)
     wrapped = module_wrapper.vars()
     self.assertEqual(len(var_collection), len(wrapped))
     # Each original name must appear as '(ModuleWrapper)<name>', bound to the
     # very same variable object.
     prefix = f'({type(module_wrapper).__name__})'
     for name, var in var_collection.items():
         wrapped_name = prefix + name
         self.assertIn(wrapped_name, wrapped)
         self.assertIs(wrapped[wrapped_name], var)
Пример #10
0
    def test_trainvar_assign_multivalue(self):
        """Check Vectorize outputs and the TrainVar value after per-row assigns."""
        counter = objax.TrainVar(jn.array((1., 2.)))
        m = objax.ModuleList([counter])

        def increase(x):
            counter.assign(counter.value + x)
            return x * 2

        rows = np.arange(10)[:, None]
        doubled = objax.Vectorize(increase, m.vars())(rows)
        self.assertEqual(doubled.tolist(), (2 * np.arange(10))[:, None].tolist())
        # Expected values are the initial (1, 2) shifted by 4.5 == mean(arange(10)).
        self.assertEqual(counter.value.tolist(), [5.5, 6.5])
Пример #11
0
    def test_trainvar_assign(self):
        """Check Vectorize outputs and that a constant increment is applied once."""
        counter = objax.TrainVar(jn.zeros(2))
        m = objax.ModuleList([counter])

        def increase(x):
            counter.assign(counter.value + 1)
            return x + 1

        rows = np.arange(10)[:, None]
        shifted = objax.Vectorize(increase, m.vars())(rows)
        self.assertEqual(shifted.tolist(), np.arange(1, 11)[:, None].tolist())
        self.assertEqual(counter.value.tolist(), [1., 1.])
Пример #12
0
    def test_trainvar_and_ref_assign(self):
        """Check that assigns via a TrainVar and a TrainRef to it accumulate under Parallel."""
        target = objax.TrainVar(jn.zeros(2))
        m = objax.ModuleList([target])
        ref = objax.TrainRef(target)
        m.append(ref)

        def increase():
            # Two increments of 1: one through the variable, one through the ref.
            target.assign(target.value + 1)
            ref.assign(ref.value + 1)
            return target.value

        para_increase = objax.Parallel(increase, m.vars())
        with m.vars().replicate():
            para_increase()
        self.assertEqual(target.value.tolist(), [2., 2.])
Пример #13
0
    def test_trainvar_and_ref_assign(self):
        """Check that assigns via a TrainVar and a TrainRef to it accumulate under Jit."""
        target = objax.TrainVar(jn.zeros(2))
        m = objax.ModuleList([target])
        ref = objax.TrainRef(target)
        m.append(ref)

        def increase():
            # Two increments of 1: one through the variable, one through the ref.
            target.assign(target.value + 1)
            ref.assign(ref.value + 1)
            return target.value

        returned = objax.Jit(increase, m.vars())()
        self.assertEqual(returned.tolist(), [2., 2.])
        self.assertEqual(target.value.tolist(), [2., 2.])
Пример #14
0
    def __init__(self):
        """Build the encoder stack, the decoder stack and the logits layer.

        Encoders output 32, 64, 128, 128 channels and decoders 128, 64, 32, 16.
        The first encoder block gets ``k=7`` and the others ``k=3``. The final
        ``Conv2D`` maps the last decoder's channels to a single output with
        ``k=1``, ``strides=1`` and ``xavier_normal`` weight init.
        """
        # Input channels: 3 from RGB_t1 + 1 from dither_t0.
        channels = 4

        self.encoders = objax.ModuleList()
        for idx, out_channels in enumerate([32, 64, 128, 128]):
            kernel = 7 if idx == 0 else 3
            self.encoders.append(EncoderBlock(channels, out_channels, kernel))
            channels = out_channels

        self.decoders = objax.ModuleList()
        for out_channels in [128, 64, 32, 16]:
            self.decoders.append(DecoderBlock(channels, out_channels))
            channels = out_channels

        self.logits = Conv2D(channels, nout=1, strides=1, k=1,
                             w_init=xavier_normal)
Пример #15
0
 def test_file_load_save_references(self):
     """Check how TrainRefs interact with save/load of var collections.

     Adding a reference to a variable that is already in the saved collection
     must not change the number of bytes written. Loading into ``refs`` (a ref
     to ``a.w``) combined with ``c``'s variables must raise ``ValueError``.
     Loading into a ref to ``c.w`` combined with ``c``'s variables must
     succeed. After that, ``b`` and ``c`` hold ``a``'s weights.
     """
     src, dst_plain, dst_ref = [objax.nn.Conv2D(16, 16, 3) for _ in range(3)]
     src_refs = objax.ModuleList([objax.TrainRef(src.w)])
     dst_refs = objax.ModuleList([objax.TrainRef(dst_ref.w)])
     # Freshly initialised layers start out with different weights.
     self.assertFalse(jn.array_equal(src.w.value, dst_plain.w.value))
     with io.BytesIO() as buf:
         objax.io.save_var_collection(buf, src.vars())
         plain_size = buf.tell()
         buf.seek(0)
         objax.io.save_var_collection(buf, src.vars() + src_refs.vars())
         self.assertEqual(plain_size, buf.tell())
         buf.seek(0)
         objax.io.load_var_collection(buf, dst_plain.vars())
         buf.seek(0)
         with self.assertRaises(ValueError):
             objax.io.load_var_collection(buf, src_refs.vars() + dst_ref.vars())
         buf.seek(0)
         objax.io.load_var_collection(buf, dst_refs.vars() + dst_ref.vars())
     for layer in (dst_plain, dst_ref):
         self.assertEqual(src.w.value.dtype, layer.w.value.dtype)
     for layer in (dst_plain, dst_ref):
         self.assertTrue(jn.array_equal(src.w.value, layer.w.value))
Пример #16
0
 def __init__(self, transforms: List[Module]) -> None:
     super().__init__()
     self._transforms = objax.ModuleList(transforms)