def __init__(self, num_models, dense_kernel_size=32, embedding_dim=32,
             seed=0, logit_temp=1.0, orthogonal_init=True):
    """Build per-member parameters for an ensemble of small conv nets.

    Every kernel/bias carries a leading ``num_models`` axis so all ensemble
    members are stored (and can be applied) together.

    Args:
        num_models: number of ensemble members; must be >= 2.
        dense_kernel_size: output channels of the 1x1 dense layer.
        embedding_dim: output channels of the final embedding kernel.
        seed: PRNG seed from which all initialisation keys are derived.
        logit_temp: logit temperature; stored on the instance, not used here.
        orthogonal_init: if True, use orthogonal init everywhere; otherwise
            he_normal for conv/dense kernels and glorot_normal for the
            embedding kernel.

    Raises:
        ValueError: if ``num_models`` <= 1.
    """
    if num_models <= 1:
        # ValueError is the conventional type for a bad argument; previously
        # this raised bare Exception, which callers cannot catch narrowly.
        raise ValueError("requires at least two models")
    self.num_models = num_models
    self.logit_temp = logit_temp

    key = random.PRNGKey(seed)
    subkeys = random.split(key, 8)  # 6 conv layers + dense + embedding

    # Conv stack kernels and biases.
    initialiser = orthogonal if orthogonal_init else he_normal
    self.conv_kernels = objax.ModuleList()
    self.conv_biases = objax.ModuleList()
    input_channels = 3
    for i, output_channels in enumerate([32, 64, 64, 64, 64, 64]):
        self.conv_kernels.append(TrainVar(initialiser()(
            subkeys[i],
            (num_models, 3, 3, input_channels, output_channels))))
        self.conv_biases.append(
            TrainVar(jnp.zeros((num_models, output_channels))))
        input_channels = output_channels

    # Dense (1x1) kernels and biases.
    self.dense_kernels = TrainVar(initialiser()(
        subkeys[6], (num_models, 1, 1, output_channels, dense_kernel_size)))
    self.dense_biases = TrainVar(jnp.zeros((num_models, dense_kernel_size)))

    # Embedding kernel; no bias or non-linearity.
    if not orthogonal_init:
        initialiser = glorot_normal
    self.embedding_kernels = TrainVar(initialiser()(
        subkeys[7], (num_models, 1, 1, dense_kernel_size, embedding_dim)))
def test_gradvalues_constant(self):
    """Test if constants are preserved."""
    # Training data for a 1-d linear regression problem.
    inputs = np.array([[1.0], [3.0], [5.0], [-10.0]])
    targets = np.array([1.0, 2.0, 3.0, 4.0])

    # Linear-regression parameters; only w is handed to GradValues, so b
    # behaves as a captured constant.
    w = objax.TrainVar(jn.zeros(1))
    b = objax.TrainVar(jn.ones(1))
    m = objax.ModuleList([w, b])

    def loss(x, y):
        prediction = jn.dot(x, w.value) + b.value
        return 0.5 * ((y - prediction) ** 2).mean()

    # The gradient must change after the value of b (the constant) changes.
    gv = objax.GradValues(loss, objax.VarCollection({'w': w}))
    g_before, _ = gv(inputs, targets)
    b.assign(-b.value)
    g_after, _ = gv(inputs, targets)
    self.assertNotEqual(g_before[0][0], g_after[0][0])

    # Same expectation when the gradient function is compiled with Jit.
    gv = objax.Jit(objax.GradValues(loss, objax.VarCollection({'w': w})),
                   m.vars())
    g_before, _ = gv(inputs, targets)
    b.assign(-b.value)
    g_after, _ = gv(inputs, targets)
    self.assertNotEqual(g_before[0][0], g_after[0][0])
def test_ema(self):
    eps = 1e-6
    initial = np.array([100.0, -1.0])
    for momentum in (i / 10.0 for i in range(1, 10)):
        x = objax.ModuleList([objax.TrainVar(jn.array(initial))])
        ema = objax.optimizer.ExponentialMovingAverage(
            x.vars(), momentum=momentum, debias=True, eps=eps)
        ema()
        # Expected debiased EMA value after a single update.
        expected_ema = (1 - momentum) * initial / (1 - (1 - eps) * momentum)

        def get_tensors():
            return x.vars().tensors()

        # Inside replace_vars the EMA shadows the raw values; outside,
        # the original tensors must be untouched.
        ema_tensors = ema.replace_vars(get_tensors)()
        raw_tensors = get_tensors()
        np.testing.assert_allclose(ema_tensors[0], expected_ema, rtol=1e-6)
        np.testing.assert_allclose(raw_tensors[0], initial, rtol=1e-6)
def test_weight_sharing(self):
    """Check weight sharing."""
    m = objax.ModuleList([objax.TrainVar(jn.ones(3))])
    shared = m.vars('m1') + m.vars('m2')
    # Two distinct names, but both entries alias one underlying variable.
    self.assertEqual(len(shared), 2)
    self.assertEqual(len(list(shared)), 1)
    self.assertEqual(len(shared.tensors()), 1)
    # Assigning through the collection updates the single shared variable.
    shared.assign([t + 1 for t in shared.tensors()])
    self.assertEqual(m[0].value.sum(), 6)
def test_module_list(self):
    """Unit test for objax.ModuleList."""
    first = SimpleModule(3, 5)
    second = SimpleModule(5, 7)
    collected = list(objax.ModuleList([first, second]).vars())
    # One variable per submodule, in insertion order.
    self.assertEqual(len(collected), 2)
    self.assertIs(collected[0], first.v1)
    self.assertIs(collected[1], second.v1)
def test_trainvar_assign(self):
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

    def bump():
        m[0].assign(m[0].value + 1)
        return m[0].value

    # An assign made inside a jitted function must persist afterwards.
    objax.Jit(bump, m.vars())()
    self.assertEqual(m[0].value.tolist(), [1., 1.])
def test_trainvar_assign_multivalue(self):
    m = objax.ModuleList([objax.TrainVar(jn.array((1., 2.)))])

    def add(x):
        m[0].assign(m[0].value + x)
        return m[0].value

    parallel_add = objax.Parallel(add, m.vars())
    with m.vars().replicate():
        parallel_add(jn.arange(8))
    # Expected: (1, 2) plus mean(0..7) = 3.5 per element.
    self.assertEqual(m[0].value.tolist(), [4.5, 5.5])
def test_trainvar_assign(self):
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

    def bump():
        m[0].assign(m[0].value + 1)
        return m[0].value

    # An assign made inside a parallelized function must persist afterwards.
    parallel_bump = objax.Parallel(bump, m.vars())
    with m.vars().replicate():
        parallel_bump()
    self.assertEqual(m[0].value.tolist(), [1., 1.])
def test_module_wrapper(self):
    """Unit test for objax.ModuleWrapper."""
    modules = objax.ModuleList([SampleModule(3, 5), SampleModule(5, 7)])
    var_collection = modules.vars()
    wrapper = objax.ModuleWrapper(var_collection)
    wrapped_vars = wrapper.vars()
    self.assertEqual(len(var_collection), len(wrapped_vars))
    # The wrapper prefixes each variable name with its own class name.
    prefix = f'({wrapper.__class__.__name__})'
    for name, variable in var_collection.items():
        renamed = prefix + name
        self.assertIn(renamed, wrapped_vars)
        self.assertIs(wrapped_vars[renamed], variable)
def test_trainvar_assign_multivalue(self):
    m = objax.ModuleList([objax.TrainVar(jn.array((1., 2.)))])

    def double(x):
        m[0].assign(m[0].value + x)
        return x * 2

    batch = np.arange(10)[:, None]
    vectorized = objax.Vectorize(double, m.vars())
    out = vectorized(batch)
    self.assertEqual(out.tolist(), (2 * np.arange(10))[:, None].tolist())
    # Expected: (1, 2) plus mean(0..9) = 4.5 per element.
    self.assertEqual(m[0].value.tolist(), [5.5, 6.5])
def test_trainvar_assign(self):
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])

    def bump(x):
        m[0].assign(m[0].value + 1)
        return x + 1

    batch = np.arange(10)[:, None]
    vectorized = objax.Vectorize(bump, m.vars())
    out = vectorized(batch)
    self.assertEqual(out.tolist(), np.arange(1, 11)[:, None].tolist())
    # The assign persists once, not once per vectorized element.
    self.assertEqual(m[0].value.tolist(), [1., 1.])
def test_trainvar_and_ref_assign(self):
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
    m.append(objax.TrainRef(m[0]))

    def bump_twice():
        m[0].assign(m[0].value + 1)
        m[1].assign(m[1].value + 1)
        return m[0].value

    parallel_bump = objax.Parallel(bump_twice, m.vars())
    with m.vars().replicate():
        parallel_bump()
    # The TrainRef aliases the TrainVar, so both increments accumulate on it.
    self.assertEqual(m[0].value.tolist(), [2., 2.])
def test_trainvar_and_ref_assign(self):
    m = objax.ModuleList([objax.TrainVar(jn.zeros(2))])
    m.append(objax.TrainRef(m[0]))

    def bump_twice():
        m[0].assign(m[0].value + 1)
        m[1].assign(m[1].value + 1)
        return m[0].value

    # The TrainRef aliases the TrainVar, so both increments accumulate on it.
    result = objax.Jit(bump_twice, m.vars())()
    self.assertEqual(result.tolist(), [2., 2.])
    self.assertEqual(m[0].value.tolist(), [2., 2.])
def __init__(self):
    # Input: 3 channels from RGB_t1 + 1 from dither_t0.
    in_channels = 4

    # Encoder stack; the first block uses a 7x7 kernel, the rest 3x3.
    self.encoders = objax.ModuleList()
    kernel_size = 7
    for out_channels in [32, 64, 128, 128]:
        self.encoders.append(
            EncoderBlock(in_channels, out_channels, kernel_size))
        kernel_size = 3
        in_channels = out_channels

    # Decoder stack steps the channel count back down to 16.
    self.decoders = objax.ModuleList()
    for out_channels in [128, 64, 32, 16]:
        self.decoders.append(DecoderBlock(in_channels, out_channels))
        in_channels = out_channels

    # Final 1x1 conv producing single-channel logits.
    self.logits = Conv2D(in_channels, nout=1, strides=1, k=1,
                         w_init=xavier_normal)
def test_file_load_save_references(self):
    a = objax.nn.Conv2D(16, 16, 3)
    b = objax.nn.Conv2D(16, 16, 3)
    c = objax.nn.Conv2D(16, 16, 3)
    a_refs = objax.ModuleList([objax.TrainRef(a.w)])
    c_refs = objax.ModuleList([objax.TrainRef(c.w)])
    # Independently initialized weights should differ before loading.
    self.assertFalse(jn.array_equal(a.w.value, b.w.value))
    with io.BytesIO() as f:
        objax.io.save_var_collection(f, a.vars())
        size_without_refs = f.tell()
        f.seek(0)
        # Saving a's vars together with references to them must not grow
        # the file: references add no payload of their own.
        objax.io.save_var_collection(f, a.vars() + a_refs.vars())
        self.assertEqual(size_without_refs, f.tell())
        f.seek(0)
        objax.io.load_var_collection(f, b.vars())
        f.seek(0)
        # Loading into a's references combined with c's vars must fail.
        with self.assertRaises(ValueError):
            objax.io.load_var_collection(f, a_refs.vars() + c.vars())
        f.seek(0)
        # Loading into c's own references plus c's vars succeeds.
        objax.io.load_var_collection(f, c_refs.vars() + c.vars())
    self.assertEqual(a.w.value.dtype, b.w.value.dtype)
    self.assertEqual(a.w.value.dtype, c.w.value.dtype)
    self.assertTrue(jn.array_equal(a.w.value, b.w.value))
    self.assertTrue(jn.array_equal(a.w.value, c.w.value))
def __init__(self, transforms: List[Module]) -> None:
    """Wrap the given transforms in an objax.ModuleList and store them."""
    super().__init__()
    self._transforms = objax.ModuleList(transforms)