def test_ConcatRecorder(self):
    """Run the shared recorder check on a fresh `ConcatRecorder`.

    Two random tensors (shapes ``[2, 4]`` and ``[3, 4]``) are handed to
    `standard_recorder_check`, together with their concatenation along
    axis 0 as the expected value.
    """
    samples = [T.random.randn(shape) for shape in ([2, 4], [3, 4])]
    recorder = ConcatRecorder()
    expected = T.concat(samples, axis=0)
    standard_recorder_check(self, recorder, samples, expected)
def test_push_to_stack(self):
    """Exercise `RecorderManager.push_to_stack` with 0, 1 and 2 managers.

    `record_tensor` is called for every tensor while no manager, only
    `rec1`, or both `rec1` and `rec2` are pushed; the contents returned
    by `get_all()` are then compared with the concatenated tensors.
    """
    rec1 = RecorderManager({
        'name1': ConcatRecorder,
        'name2': ConcatRecorder,
    })
    rec2 = RecorderManager({'name1': ConcatRecorder})

    tensors = {
        'name1': [T.random.randn([2, 4]), T.random.randn([3, 4])],
        'name2': [T.random.randn([5]), T.random.randn([6])],
    }
    answers = {
        name: T.concat(values, axis=0)
        for name, values in tensors.items()
    }

    def populate():
        # Reset both managers, then record every tensor through the
        # module-level `record_tensor` entry point.
        for manager in (rec1, rec2):
            manager.clear()
        for name, values in tensors.items():
            for value in values:
                record_tensor(name, value)

    def check_records(records, names):
        # `records` must hold exactly the keys in `names`, each matching
        # the expected concatenation.
        self.assertEqual(len(records), len(names))
        self.assertEqual(sorted(names), sorted(records))
        for name in names:
            assert_allclose(records[name], answers[name])

    # test empty
    self.assertFalse(has_recorder_manager())
    populate()
    self.assertEqual(rec1.get_all(), {})
    self.assertEqual(rec2.get_all(), {})

    # test one recorder
    self.assertFalse(has_recorder_manager())
    with rec1.push_to_stack() as pushed:
        self.assertIs(pushed, rec1)
        self.assertTrue(has_recorder_manager())
        populate()
    self.assertFalse(has_recorder_manager())
    check_records(rec1.get_all(), ['name1', 'name2'])
    check_records(rec2.get_all(), [])

    # test two recorders
    self.assertFalse(has_recorder_manager())
    with rec1.push_to_stack(), rec2.push_to_stack():
        self.assertTrue(has_recorder_manager())
        populate()
    self.assertFalse(has_recorder_manager())
    check_records(rec1.get_all(), ['name1', 'name2'])
    check_records(rec2.get_all(), ['name1'])
def do_check(secondary, scale_type):
    """Compare a flow built by `cls` with a manually computed result.

    Args:
        secondary: Forwarded to `cls`; also selects which of the two
            enclosing `shift_and_pre_scale_*` callables is used and
            whether the two channel halves swap roles.
        scale_type: ``'exp'``, ``'sigmoid'``, ``'linear'``, one of the
            classes `ExpScale` / `SigmoidScale` / `LinearScale`, a
            `Scale` instance, or a jit layer.

    Raises:
        ValueError: If `scale_type` is none of the accepted values.
    """
    input_shape = make_conv_shape(
        batch_shape, num_features, [6, 7, 8][:spatial_ndims])
    x = T.random.randn(input_shape)
    first_size = num_features // 2
    second_size = num_features - num_features // 2

    # construct the instance
    if secondary:
        shift_and_pre_scale = shift_and_pre_scale_2
    else:
        shift_and_pre_scale = shift_and_pre_scale_1
    flow = cls(
        shift_and_pre_scale,
        scale=scale_type,
        secondary=secondary,
        sigmoid_scale_bias=sigmoid_scale_bias,
    )
    ctx.assertIn(f'secondary={secondary}', repr(flow))
    flow = tk.layers.jit_compile(flow)

    # obtain the expected output
    channel_axis = get_channel_axis(spatial_ndims)
    front, back = T.split(x, [first_size, second_size], axis=channel_axis)
    # The "passthrough" half is copied unchanged and feeds the
    # shift/pre-scale callable; the other half is transformed.
    if secondary:
        passthrough, to_transform = back, front
    else:
        passthrough, to_transform = front, back

    shift, pre_scale = shift_and_pre_scale(passthrough)

    if scale_type == 'exp' or scale_type is ExpScale:
        scale = ExpScale()
    elif scale_type == 'sigmoid' or scale_type is SigmoidScale:
        scale = SigmoidScale(pre_scale_bias=sigmoid_scale_bias)
    elif scale_type == 'linear' or scale_type is LinearScale:
        scale = LinearScale()
    elif isinstance(scale_type, Scale) or tk.layers.is_jit_layer(scale_type):
        scale = scale_type
    else:
        raise ValueError(f'Invalid value for `scale`: {scale_type}')

    transformed, log_det = scale(
        to_transform + shift,
        pre_scale,
        event_ndims=spatial_ndims + 1,
        compute_log_det=True,
    )

    # Put the halves back in their original channel order.
    if secondary:
        pieces = [transformed, passthrough]
    else:
        pieces = [passthrough, transformed]
    expected_y = T.concat(pieces, axis=channel_axis)
    expected_log_det = log_det

    # now check the flow
    input_log_det = T.random.randn(batch_shape)
    flow_standard_check(
        ctx, flow, x, expected_y, expected_log_det, input_log_det)
def test_record_and_clear(self):
    """Check `record`, `get`, `iter_all`, `get_all` and `clear`.

    The same scenario is run twice with a `clear()` in between, so the
    second pass also verifies that the manager starts out empty again.
    """
    rm = RecorderManager(default_factory=ConcatRecorder)
    tensors = {
        'name1': [T.random.randn([2, 4]), T.random.randn([3, 4])],
        'name2': [T.random.randn([5]), T.random.randn([6])],
    }
    answers = {
        name: T.concat(values, axis=0)
        for name, values in tensors.items()
    }

    def check_records(records, names):
        # Accept either a dict or an iterable of (name, value) pairs.
        if not isinstance(records, dict):
            records = dict(records)
        self.assertEqual(len(records), len(names))
        self.assertEqual(sorted(names), sorted(records))
        for name in names:
            assert_allclose(records[name], answers[name])

    def only_name1(name):
        return name == 'name1'

    def run_scenario():
        # The manager must be empty on entry.
        for name in tensors:
            self.assertIsNone(rm.get(name))
        self.assertEqual(list(rm.iter_all()), [])
        self.assertEqual(rm.get_all(), {})

        # Record everything, then read each name back individually.
        for name, values in tensors.items():
            for value in values:
                rm.record(name, value)
        for name in tensors:
            assert_allclose(rm.get(name), answers[name])

        # Unfiltered and filtered bulk access.
        check_records(rm.iter_all(), ['name1', 'name2'])
        check_records(rm.get_all(), ['name1', 'name2'])
        check_records(rm.iter_all(only_name1), ['name1'])
        check_records(rm.get_all(only_name1), ['name1'])

    run_scenario()
    rm.clear()
    run_scenario()
def f(modules, self_module, self_weight, use_bias, normalizer, activation,
      merge_mode):
    """Build a layer from `cls` and compare it with a manual computation.

    Relies on names from the enclosing scope (`n_partitions`, `in_adj`,
    `add_out_shape`, `feature_axis`, `inputs`, `value_ndims`, `cls`,
    `self`).  For every tensor in `inputs` it checks the expected
    argument errors, the output shape, and the output values.
    """
    adj = in_adj[0] if n_partitions == 1 else in_adj[:len(modules)]

    out_shape = list(add_out_shape)
    if merge_mode == 'concat':
        # One slice per module output, plus one for the self module.
        n_merged = len(modules) + int(self_module is not None)
        out_shape[feature_axis] *= n_merged

    if use_bias:
        bias_store = SimpleParamStore(out_shape, initializer=tk.init.normal)
    else:
        bias_store = None

    layer_kwargs = {
        'self_module': self_module,
        'self_weight': self_weight,
        'bias_store': bias_store,
        'normalizer': normalizer,
        'activation': activation,
        'merge_mode': merge_mode,
    }
    if n_partitions == 1:
        raw_layer = cls(module=modules[0], **layer_kwargs)
    else:
        raw_layer = cls(modules=modules, **layer_kwargs)
    layer = jit_compile(raw_layer)

    # `activation` may be given as a class; instantiate it in that case.
    activation_layer = (
        activation() if isinstance(activation, type) else activation)

    def apply_module(module, tensor):
        # Flatten the leading dims, apply `module`, restore the dims.
        flat, front_shape = T.flatten_to_ndims(tensor, value_ndims + 1)
        flat = module(flat)
        return T.unflatten_from_ndims(flat, front_shape)

    for x in inputs:
        # test errors
        if len(modules) > 1:
            with pytest.raises(
                    Exception,
                    match=r'`adj` is expected to have .* element'
                          r'\(s\), but got .*'):
                _ = layer(x, in_adj[:len(modules) - 1])

        if T.rank(x) == value_ndims + 1:
            with pytest.raises(
                    Exception,
                    match='`input` is expected to be at least .*d'):
                _ = layer(x[0], adj)

        # obtain valid output
        y = layer(x, adj)
        self.assertEqual(T.shape(y), T.shape(x)[:-value_ndims] + out_shape)

        # compute the expected output
        outs = []
        for module, adj_matrix in zip(modules, in_adj):
            # NOTE(review): the hard-coded 50 presumably is the node
            # count of the fixtures in the enclosing test; confirm.
            propagated = np.dot(
                T.sparse.to_numpy(adj_matrix),
                T.to_numpy(x).reshape([50, -1]),
            )
            propagated = T.as_tensor(np.reshape(propagated, x.shape))
            outs.append(apply_module(module, propagated))
        if self_module is not None:
            outs.append(apply_module(self_module, x))

        if merge_mode == 'add':
            out = T.add_n(outs)
        elif merge_mode == 'concat':
            out = T.concat(outs, axis=feature_axis)

        if bias_store is not None:
            out = out + bias_store()
        if normalizer is not None:
            out = normalizer(out)
        if activation is not None:
            out = activation_layer(out)

        # assert the output is expected
        assert_allclose_(y, out)
def check_split_flow(ctx, spatial_ndims: int, num_features: int, cls,
                     x_sections, left, right, **kwargs):
    """Check a split-flow class `cls` with and without a `right` flow.

    Args:
        ctx: Test case object providing `assertIn`.
        spatial_ndims: Number of spatial dimensions of the random input.
        num_features: Feature count passed to `make_conv_shape`.
        cls: The flow class under test.
        x_sections: Section sizes used to split the input.
        left: Flow applied to the first section.
        right: Flow applied to the second section.
        **kwargs: Forwarded to `cls`; `y_sections`, `x_axis` and
            `y_axis` are also read here to compute the expected output.
    """
    batch_shape = [2]
    y_sections = kwargs.get("y_sections", x_sections)
    x_axis = kwargs.get("x_axis", get_channel_axis(spatial_ndims))
    y_axis = kwargs.get("y_axis", x_axis)

    input_shape = make_conv_shape(
        batch_shape, num_features, [6, 7, 8][:spatial_ndims])
    x = T.random.randn(input_shape)
    input_log_det = T.random.randn(batch_shape)

    # without right
    if y_axis == x_axis:
        flow = cls(x_sections, left, None, **kwargs)
        expected_fragments = (
            f'x_sections={x_sections}',
            f'y_sections={y_sections}',
            f'x_axis={x_axis}',
            f'y_axis={y_axis}',
        )
        for fragment in expected_fragments:
            ctx.assertIn(fragment, repr(flow))
        flow = tk.layers.jit_compile(flow)

        x_left, x_right = T.split(x, x_sections, axis=x_axis)
        y_left, expected_log_det = left(x_left, compute_log_det=True)
        # With no `right` flow the second section passes through as-is.
        expected_y = T.concat([y_left, x_right], axis=y_axis)
        flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                            input_log_det)

    # with right
    flow = tk.layers.jit_compile(cls(x_sections, left, right, **kwargs))

    x_left, x_right = T.split(x, x_sections, axis=x_axis)
    y_left, expected_log_det = left(x_left, compute_log_det=True)
    # The log-det of `left` is chained into `right`.
    y_right, expected_log_det = right(
        x_right, input_log_det=expected_log_det, compute_log_det=True)
    expected_y = T.concat([y_left, y_right], axis=y_axis)
    flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                        input_log_det)

    # test argument error
    x_sections_error = ('`x_sections` must be a sequence of '
                        'two positive integers')
    y_sections_error = ('`y_sections` must be None or a sequence of '
                        'two positive integers')
    invalid_sections = ([1, 2, 3], [-1, 2])

    for bad_sections in invalid_sections:
        with pytest.raises(ValueError, match=x_sections_error):
            _ = cls(bad_sections, left)

    for bad_sections in invalid_sections:
        with pytest.raises(ValueError, match=y_sections_error):
            _ = cls([2, 3], left, right, y_sections=bad_sections)

    with pytest.raises(TypeError, match='`left` is not a flow'):
        _ = cls([2, 3], tk.layers.Linear(2, 3))

    with pytest.raises(TypeError, match='`right` is not a flow'):
        _ = cls([2, 3], left, tk.layers.Linear(2, 3))
def test_embedding(self):
    """Check the `Embedding*` layers: weights, outputs, grads and errors.

    For each spatial dimensionality 0..3 the matching layer class is
    looked up on `tk.layers`, its weight shape and per-channel mean are
    checked, and its output is compared with `T.embedding`.  Gradients
    are verified for the 0-d and 1-d variants only.  Finally a frozen
    `Embedding` and the empty-`embedding_size` error are tested.
    """
    n_channels = 3
    n_embeddings = n_samples

    for spatial_ndims in (0, 1, 2, 3):
        w_shape = make_conv_shape([], n_channels, [4, 5, 6][:spatial_ndims])
        w_size = int(np.prod(w_shape))

        layer_name = (f'Embedding{spatial_ndims}d' if spatial_ndims > 0
                      else 'Embedding')
        layer = getattr(tk.layers, layer_name)(n_embeddings, w_shape)
        weight = layer.weight

        # check the weight
        self.assertEqual(T.shape(weight), [n_embeddings] + w_shape)
        # Average over every axis except the channel axis.
        reduce_axis = list(range(len(w_shape) + 1))
        reduce_axis.pop(-1 if T.IS_CHANNEL_LAST else 1)
        w_mean = np.average(T.to_numpy(weight), axis=tuple(reduce_axis))
        # The bound is 3 / sqrt(number of elements averaged per channel);
        # presumably a three-sigma band for unit-variance initial weights
        # (TODO confirm the default initializer).  The check must be
        # two-sided: without `abs`, an arbitrarily negative mean would
        # satisfy the strict less-than comparison.
        np.testing.assert_array_less(
            np.abs(w_mean),
            3. / np.sqrt(n_samples * w_size / n_channels))

        # check the output
        layer = jit_compile(layer)
        weight_array = T.to_numpy(weight)

        for in_shape in ([7], [7, 8]):
            indices = T.random.randint(0, n_samples, in_shape)
            # Repeat the first rows so that duplicate indices occur.
            indices = T.concat([indices, indices[:3]], axis=0)

            # check the output
            output = layer(indices)
            assert_allclose(output, T.embedding(weight, indices))

            # check the grad
            if spatial_ndims in (0, 1):
                out_sum = T.reduce_sum(output**2)
                [grad] = T.grad([out_sum], [weight])
                # d(sum(w[idx]**2))/dw accumulates 2 * w[idx] once per
                # occurrence of idx.
                expected_grad = np.zeros(T.shape(weight))
                for idx in T.to_numpy(indices).reshape([-1]):
                    expected_grad[idx] += 2. * weight_array[idx]
                assert_allclose(grad, expected_grad)

        # test the constructor error
        if spatial_ndims > 0:
            with pytest.raises(
                    ValueError,
                    match=f'`embedding_size` must be a int list '
                          f'with {spatial_ndims + 1} elements'):
                _ = getattr(tk.layers, f'Embedding{spatial_ndims}d')(
                    n_embeddings, w_shape[:-1])

    # test no grad
    layer = Embedding(n_embeddings, n_channels, freeze=True)
    weight = layer.weight
    self.assertEqual(T.shape(weight), [n_embeddings, n_channels])

    layer = jit_compile(layer)
    indices = T.random.randint(0, n_samples, [7, 8])
    output = layer(indices)
    assert_allclose(output, T.embedding(weight, indices))

    out_sum = T.reduce_sum(output**2)
    try:
        [grad] = T.grad([out_sum], [weight])
    except Exception:
        # Deliberate best-effort: a failure to differentiate w.r.t. the
        # frozen weight is accepted as "no gradient" (presumably some
        # backends raise here instead of returning a null grad).
        pass
    else:
        self.assertTrue(T.is_null_grad(weight, grad))

    # test errors
    with pytest.raises(ValueError,
                       match='`embedding_size` must not be empty'):
        _ = Embedding(n_embeddings, [])