Ejemplo n.º 1
0
 def plot_samples(epoch=None):
     """Sample images from ``vae.p`` and save them as a 10x10 grid.

     Args:
         epoch: Number used in the output file name ``plotting/{epoch}.png``.
             When ``None`` (the default), ``loop.epoch`` is used instead.
     """
     # Compare against None explicitly: ``epoch or loop.epoch`` would treat a
     # legitimate ``epoch=0`` as "missing" and overwrite it with loop.epoch.
     if epoch is None:
         epoch = loop.epoch
     with tk.layers.scoped_eval_mode(vae), T.no_grad():
         # Draw n_z=100 samples; the 10x10 grid below assumes exactly 100.
         logits = vae.p(n_z=100)['x'].distribution.logits
         # sigmoid(logits) scaled to [0, 255] and cast to uint8 (the cast
         # truncates rather than rounds), then reshaped into 28x28 images.
         images = T.reshape(
             T.cast(T.clip(T.nn.sigmoid(logits) * 255., 0., 255.), dtype=T.uint8),
             [-1, 28, 28],
         )
     utils.save_images_collection(
         images=T.to_numpy(images),
         filename=exp.abspath(f'plotting/{epoch}.png'),
         grid_size=(10, 10),
     )
Ejemplo n.º 2
0
    def test_get_activation_class(self):
        """Check name -> class resolution of ``tk.layers.get_activation_class``.

        Each listed name is tried in four spellings (as written, lower-case,
        underscores removed, and both). The resolved class must be identical
        (``assertIs``) to the expected factory. When a factory is expected, a
        layer built from the listed constructor arguments is applied to a
        random tensor and compared with the reference result. Finally, an
        unknown name must raise ``ValueError``.
        """
        x = T.random.randn([2, 3, 4])

        # (name, expected class, ctor positional args, ctor keyword args,
        #  expected output on ``x``); a ``None`` class means the name is
        #  expected to resolve to ``None``.
        cases = [
            ('Linear', None, None, None, None),
            ('ReLU', tk.layers.ReLU, (), {}, T.nn.relu(x)),
            ('Leaky_ReLU', tk.layers.LeakyReLU, (), {}, T.nn.leaky_relu(x)),
            ('Leaky_ReLU', tk.layers.LeakyReLU, (0.2, ), {},
             T.nn.leaky_relu(x, 0.2)),
            ('Leaky_ReLU', tk.layers.LeakyReLU, (),
             {'negative_slope': 0.2}, T.nn.leaky_relu(x, 0.2)),
            ('Sigmoid', tk.layers.Sigmoid, (), {}, T.nn.sigmoid(x)),
            ('Tanh', tk.layers.Tanh, (), {}, T.tanh(x)),
            ('HardTanh', tk.layers.HardTanh, (), {}, T.clip(x, -1., 1.)),
            ('HardTanh', tk.layers.HardTanh, (-2., 3.), {},
             T.clip(x, -2., 3.)),
            ('HardTanh', tk.layers.HardTanh, (),
             {'min_val': -2., 'max_val': 3.}, T.clip(x, -2., 3.)),
            ('Log_Softmax', tk.layers.LogSoftmax, (), {},
             T.nn.log_softmax(x)),
        ]

        for origin_name, factory, args, kwargs, expected in cases:
            if origin_name is None:
                spellings = (None, )
            else:
                no_underscore = origin_name.replace('_', '')
                spellings = (origin_name, origin_name.lower(),
                             no_underscore, no_underscore.lower())

            for name in spellings:
                err_msg = f'{name}, {factory}, {args}, {kwargs}, {expected}'
                self.assertIs(tk.layers.get_activation_class(name), factory)
                if factory is None:
                    continue  # nothing to instantiate or evaluate
                layer = factory(*args, **kwargs)
                assert_allclose(layer(x), expected, err_msg=err_msg)

        # Unknown names must be rejected with a descriptive error.
        with pytest.raises(ValueError,
                           match='Unsupported activation: invalid'):
            tk.layers.get_activation_class('invalid')
Ejemplo n.º 3
0
 def naive_clip_by_value(grads, clip_min, clip_max):
     """Apply ``T.clip(g, clip_min, clip_max)`` to every entry of ``grads``.

     Args:
         grads: Iterable of gradients (presumably tensors accepted by
             ``T.clip`` — confirm against callers).
         clip_min: Lower bound forwarded to ``T.clip``.
         clip_max: Upper bound forwarded to ``T.clip``.

     Returns:
         A new list holding one clipped value per input, in input order.
     """
     return [T.clip(grad, clip_min, clip_max) for grad in grads]