def test_InverseFlow(self):
    """Check ``InverseFlow`` wrapping a JIT-compiled ``_MyFlow``.

    Verifies that the wrapper exposes (and inverts back to) the original
    flow, reports swapped x/y event ndims, matches the expected output and
    log-determinant, and rejects flows that are not explicitly invertible.
    """
    wrapped = tk.layers.jit_compile(_MyFlow())
    inv_flow = InverseFlow(wrapped)
    self.assertIs(inv_flow.original_flow, wrapped)
    self.assertIs(inv_flow.invert(), wrapped)

    inv_flow = tk.layers.jit_compile(inv_flow)
    # event ndims are reported swapped relative to the wrapped flow
    self.assertEqual(inv_flow.get_x_event_ndims(), 2)
    self.assertEqual(inv_flow.get_y_event_ndims(), 1)
    self.assertTrue(inv_flow.is_explicitly_invertible())

    # expected: y = (x - 1) / 2, reshaped from [2, 3, 4, 1] to [2, 3, 4];
    # each sample has 4 elements scaled by 0.5, hence log det = -4 * log(2)
    x = T.random.randn([2, 3, 4, 1])
    expected_y = T.reshape((x - 1.) * 0.5, [2, 3, 4])
    expected_log_det = -T.full([2, 3], math.log(2.) * 4)
    input_log_det = T.random.randn([2, 3])
    flow_standard_check(self, inv_flow, x, expected_y, expected_log_det,
                        input_log_det)

    # wrapping must fail for flows that are not explicitly invertible
    not_invertible_msg = '`flow` must be an explicitly invertible flow'
    with pytest.raises(TypeError, match=not_invertible_msg):
        _ = InverseFlow(tk.layers.Linear(5, 3))

    non_invertible = _MyFlow()
    non_invertible.explicitly_invertible = False
    with pytest.raises(TypeError, match=not_invertible_msg):
        _ = InverseFlow(tk.layers.jit_compile(non_invertible))
def test_batch_norm(self):
    """Check the ``BatchNorm`` layers for 0, 1, 2 and 3 spatial dimensions.

    For each layer, one forward pass is run in training mode. The outputs
    obtained via ``set_train_mode(layer, False)`` and via
    ``set_eval_mode(layer)`` must both match a manual normalization built
    from the layer's ``weight``, ``bias``, ``running_mean`` and
    ``running_var``. Finally, an input lacking the batch dimension must be
    rejected.

    Raises:
        RuntimeError: If the active backend is not PyTorch, because the
            expected output is only computed for that backend here.
    """
    eps = T.EPSILON
    for spatial_ndims in (0, 1, 2, 3):
        cls = getattr(tk.layers, ('BatchNorm' if not spatial_ndims
                                  else f'BatchNorm{spatial_ndims}d'))
        layer = cls(5, momentum=0.1, epsilon=eps)
        self.assertIn('BatchNorm', repr(layer))
        self.assertTrue(tk.layers.is_batch_norm(layer))
        layer = tk.layers.jit_compile(layer)

        # layer output
        x = T.random.randn(make_conv_shape(
            [3], 5, [6, 7, 8][:spatial_ndims]
        ))
        # one pass in training mode (presumably to update the running stats)
        set_train_mode(layer)
        _ = layer(x)
        # evaluation-mode output, reached through two different APIs
        set_train_mode(layer, False)
        y = layer(x)
        set_train_mode(layer, True)
        set_eval_mode(layer)
        y2 = layer(x)

        # manually compute the expected output
        if T.backend_name == 'PyTorch':
            # reshape per-channel vectors to [C, 1, ..., 1] so they
            # broadcast over the trailing spatial axes
            dst_shape = [-1] + [1] * spatial_ndims
            weight = T.reshape(layer.weight, dst_shape)
            bias = T.reshape(layer.bias, dst_shape)
            running_mean = T.reshape(layer.running_mean, dst_shape)
            running_var = T.reshape(layer.running_var, dst_shape)
            expected = (((x - running_mean) / T.sqrt(running_var + eps)) *
                        weight + bias)
        else:
            # fail loudly and say why, instead of a bare RuntimeError()
            raise RuntimeError(
                f'expected batch-norm output is not implemented for '
                f'backend {T.backend_name!r}'
            )

        # check output
        assert_allclose(y, expected, rtol=1e-4, atol=1e-6)
        assert_allclose(y2, expected, rtol=1e-4, atol=1e-6)

        # check invalid dimensions: an input without the batch dimension
        with pytest.raises(Exception, match='torch|only supports .d input'):
            _ = layer(
                T.random.randn(make_conv_shape(
                    [], 5, [6, 7, 8][:spatial_ndims]
                ))
            )
def plot_samples(epoch=None):
    """Save a 10x10 grid of 28x28 images sampled from ``vae``.

    Args:
        epoch: Number used to name the output file
            ``plotting/{epoch}.png``. Defaults to ``loop.epoch`` when
            ``None``.
    """
    # `is None` rather than `or`, so an explicit epoch=0 is respected
    if epoch is None:
        epoch = loop.epoch
    with tk.layers.scoped_eval_mode(vae), T.no_grad():
        logits = vae.p(n_z=100)['x'].distribution.logits
        # sigmoid probabilities scaled to [0, 255] and stored as uint8 pixels
        images = T.reshape(
            T.cast(T.clip(T.nn.sigmoid(logits) * 255., 0., 255.),
                   dtype=T.uint8),
            [-1, 28, 28],
        )
    utils.save_images_collection(
        images=T.to_numpy(images),
        filename=exp.abspath(f'plotting/{epoch}.png'),
        grid_size=(10, 10),
    )
def test_call(self):
    """Check calling a JIT-compiled ``_MyFlow`` and its shape validation.

    Covers the forward/inverse results and log-determinant, rejection of
    inputs with too few dimensions, rejection of a mis-shaped
    ``input_log_det``, and rejection of a flow (``_MyBadFlow``) whose
    output log-determinant has the wrong shape.
    """
    my_flow = tk.layers.jit_compile(_MyFlow())
    self.assertEqual(my_flow.get_x_event_ndims(), 1)
    self.assertEqual(my_flow.get_y_event_ndims(), 2)
    self.assertEqual(my_flow.is_explicitly_invertible(), True)

    # expected: y = 2x + 1, reshaped from [2, 3, 4] to [2, 3, 4, 1];
    # log det = 4 * log(2) per sample
    inputs = T.random.randn([2, 3, 4])
    expected_y = T.reshape(inputs * 2. + 1., [2, 3, 4, 1])
    expected_log_det = T.full([2, 3], math.log(2.) * 4)
    input_log_det = T.random.randn([2, 3])
    flow_standard_check(self, my_flow, inputs, expected_y, expected_log_det,
                        input_log_det)

    # inputs with too few dimensions are rejected in both directions
    too_few_dims = '`input` is required to be at least .*d'
    with pytest.raises(Exception, match=too_few_dims):
        _ = my_flow(T.random.randn([]))
    with pytest.raises(Exception, match=too_few_dims):
        _ = my_flow(T.random.randn([3]), inverse=True)

    # a mis-shaped `input_log_det` is rejected in both directions
    bad_log_det = 'The shape of `input_log_det` is not expected'
    with pytest.raises(Exception, match=bad_log_det):
        _ = my_flow(inputs, T.random.randn([2, 4]))
    with pytest.raises(Exception, match=bad_log_det):
        _ = my_flow(expected_y, T.random.randn([2, 4]), inverse=True)

    # a flow emitting a wrongly-shaped output log det must be rejected
    bad_flow = tk.layers.jit_compile(_MyBadFlow())
    shape_error = '(shape|size)'
    with pytest.raises(Exception, match=shape_error):
        _ = bad_flow(inputs)
    with pytest.raises(Exception, match=shape_error):
        _ = bad_flow(inputs, inverse=True)
def check_invertible_linear(ctx,
                            spatial_ndims: int,
                            invertible_linear_factory,
                            linear_factory,
                            strict: bool,
                            ):
    """Check an invertible linear flow against an equivalent plain layer.

    The flow's output must equal that of a bias-free layer built by
    ``linear_factory`` and initialized with the flow's invertible matrix.
    When ``spatial_ndims > 0``, that layer is built with ``kernel_size=1``.
    The flow's log-determinant must equal the matrix's log-determinant
    summed over all spatial positions.

    Args:
        ctx: Test case object used for assertions.
        spatial_ndims: Number of spatial dimensions of the input (0 to 3).
        invertible_linear_factory: Builds the flow under test; called as
            ``invertible_linear_factory(num_features, strict=strict)``.
        linear_factory: Builds the reference layer.
        strict: Forwarded to ``invertible_linear_factory``.
    """
    batch_shape = [2]
    num_features = 4
    spatial_shape = [5, 6, 7][:spatial_ndims]
    x = T.random.randn(
        make_conv_shape(batch_shape, num_features, spatial_shape))

    # build the flow under test
    flow = invertible_linear_factory(num_features, strict=strict)
    ctx.assertIn(f'num_features={num_features}', repr(flow))
    flow = tk.layers.jit_compile(flow)

    # build the reference layer from the flow's matrix
    weight, log_det = flow.invertible_matrix(
        inverse=False, compute_log_det=True)
    extra_kwargs = {'kernel_size': 1} if spatial_ndims > 0 else {}
    kernel = T.reshape(weight, T.shape(weight) + [1] * spatial_ndims)
    reference = linear_factory(
        num_features,
        num_features,
        weight_init=kernel,
        use_bias=False,
        **extra_kwargs
    )

    # apply the reference layer on the flattened batch, then restore shape
    flat_x, front_shape = T.flatten_to_ndims(x, spatial_ndims + 2)
    expected_y = T.unflatten_from_ndims(reference(flat_x), front_shape)
    # per-position log det, summed over spatial dims, broadcast over batch
    per_position = T.expand(log_det, spatial_shape)
    expected_log_det = T.expand(T.reduce_sum(per_position), batch_shape)

    # check the invertible layer
    input_log_det = T.random.randn(batch_shape)
    flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                        input_log_det)
def do_check(batch_shape, scale_type, initialized, dtype):
    """Check one configuration of the ActNorm-style flow built by ``cls``.

    Uses ``ctx``, ``cls``, ``num_features``, ``spatial_ndims`` and
    ``channel_axis`` from the enclosing scope (not visible here).

    Args:
        batch_shape: Batch shape of the random test input.
        scale_type: Scale type; ``'exp'`` selects ``ExpScale``, anything
            else ``LinearScale``.
        initialized: Whether the flow is constructed as already
            initialized.
        dtype: Dtype of the inputs and of the flow.
    """
    x = T.random.randn(
        make_conv_shape(batch_shape, num_features,
                        [6, 7, 8][:spatial_ndims]),
        dtype=dtype,
    )

    # construction and repr
    flow = cls(num_features, scale=scale_type, initialized=initialized,
               dtype=dtype)
    for expected_text in (f'num_features={num_features}',
                          f'axis={-(spatial_ndims + 1)}',
                          f'scale_type={scale_type!r}'):
        ctx.assertIn(expected_text, repr(flow))
    flow = tk.layers.jit_compile(flow)

    # initialization
    if initialized:
        # an already-initialized flow starts out as the identity
        y, _ = flow(x, compute_log_det=False)
        assert_allclose(y, x, rtol=1e-4, atol=1e-6)
    else:
        # initializing requires an input that includes batch dimensions
        unbatched = T.random.randn(
            make_conv_shape([], num_features, [6, 7, 8][:spatial_ndims]),
            dtype=dtype,
        )
        with pytest.raises(Exception, match='with at least .* dimensions'):
            _ = flow(unbatched)
        # initializing must happen with inverse = False
        with pytest.raises(Exception,
                           match='`ActNorm` must be initialized with '
                                 '`inverse = False`'):
            _ = flow(x, inverse=True)
        # the first forward call initializes the flow: per-channel outputs
        # should then have zero mean and unit variance
        y, _ = flow(x, compute_log_det=False)
        reduce_axes = [a for a in range(-T.rank(y), 0) if a != channel_axis]
        y_mean, y_var = T.calculate_mean_and_var(y, axis=reduce_axes)
        assert_allclose(y_mean, T.zeros([num_features]),
                        rtol=1e-4, atol=1e-6)
        assert_allclose(y_var, T.ones([num_features]),
                        rtol=1e-4, atol=1e-6)

    # expected result: scale(x + bias) with the flow's own parameters,
    # reshaped so they broadcast along the channel axis
    scale_obj = ExpScale() if scale_type == 'exp' else LinearScale()
    aligned_shape = ([num_features] if T.IS_CHANNEL_LAST
                     else [num_features] + [1] * spatial_ndims)
    bias = T.reshape(flow.bias, aligned_shape)
    pre_scale = T.reshape(flow.pre_scale, aligned_shape)
    expected_y, expected_log_det = scale_obj(
        x + bias, pre_scale,
        event_ndims=(spatial_ndims + 1),
        compute_log_det=True,
    )
    flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                        T.random.randn(T.shape(expected_log_det)))