Example 1
    def random_generate(self, num_images, path):

        # Generate from the uniform prior of the base model
        indices = F.randint(low=0,
                            high=self.num_embedding,
                            shape=[num_images] + self.latent_shape)
        indices = F.reshape(indices, (-1, ), inplace=True)
        quantized = F.embed(indices, self.base_model.vq.embedding_weight)
        quantized = F.transpose(
            quantized.reshape([num_images] + self.latent_shape +
                              [quantized.shape[-1]]), (0, 3, 1, 2))

        img_gen_uniform_prior = self.base_model(quantized,
                                                quantized_as_input=True,
                                                test=True)

        # Generate images using pixelcnn prior
        indices = nn.Variable.from_numpy_array(
            np.zeros(shape=[num_images] + self.latent_shape))
        labels = F.randint(low=0, high=self.num_classes, shape=(num_images, 1))
        labels = F.one_hot(labels, shape=(self.num_classes, ))

        # Sample from pixelcnn - pixel by pixel
        import torch  # numpy sampling behaved differently here and did not give the correct output
        for i in range(self.latent_shape[0]):
            for j in range(self.latent_shape[1]):
                quantized = F.embed(indices.reshape((-1, )),
                                    self.base_model.vq.embedding_weight)
                quantized = F.transpose(
                    quantized.reshape([num_images] + self.latent_shape +
                                      [quantized.shape[-1]]), (0, 3, 1, 2))
                indices_sample = self.prior(quantized, labels)
                indices_prob = F.reshape(indices_sample,
                                         indices.shape +
                                         (indices_sample.shape[-1], ),
                                         inplace=True)[:, i, j]
                indices_prob = F.softmax(indices_prob)

                indices_prob_tensor = torch.from_numpy(indices_prob.d)
                sample = indices_prob_tensor.multinomial(1).squeeze().numpy()
                indices[:, i, j] = sample

        print(indices.d)
        quantized = F.embed(indices.reshape((-1, )),
                            self.base_model.vq.embedding_weight)
        quantized = F.transpose(
            quantized.reshape([num_images] + self.latent_shape +
                              [quantized.shape[-1]]), (0, 3, 1, 2))

        img_gen_pixelcnn_prior = self.base_model(quantized,
                                                 quantized_as_input=True,
                                                 test=True)

        self.save_image(img_gen_uniform_prior,
                        os.path.join(path, 'generate_uniform.png'))
        self.save_image(img_gen_pixelcnn_prior,
                        os.path.join(path, 'generate_pixelcnn.png'))

        print('Random labels generated for pixelcnn prior:',
              list(F.max(labels, axis=1, only_index=True).d))
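
The two torch lines above only perform a categorical draw from each row of softmax probabilities. As a hedged sketch (not the repository's method; its comment notes that a plain numpy replacement did not reproduce its output), the same per-row draw can be written with numpy's Generator.choice, renormalizing each row to guard against float32 rounding:

import numpy as np

rng = np.random.default_rng(0)

def sample_categorical_rows(prob):
    # prob: (num_images, num_embedding) softmax outputs; one index drawn per row.
    prob = prob / prob.sum(axis=1, keepdims=True)  # renormalize float32 rounding error
    return np.array([rng.choice(prob.shape[1], p=row) for row in prob])

# Hypothetical drop-in for the two torch lines:
# sample = sample_categorical_rows(indices_prob.d)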
Example 2
def test_randint_forward(seed, ctx, func_name, low, high, shape):
    with nn.context_scope(ctx):
        o = F.randint(low, high, shape, seed=seed)
    assert o.shape == tuple(shape)
    assert o.parent.name == func_name
    o.forward()
    assert np.all(o.d < high)
    assert np.all(o.d >= low)
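
The test above builds the graph and checks the bounds after a forward pass. A minimal standalone sketch of the same call (the values here are arbitrary, chosen for illustration):

import nnabla.functions as F

# Draw a 4x3 batch of integers from [0, 10); a fixed seed makes the draw repeatable.
x = F.randint(low=0, high=10, shape=(4, 3), seed=313)
x.forward()
print(x.d)  # samples are available as a NumPy array after forward()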
Example 3
def random_flip(x):
    r"""Random flipping sign of a Variable.

    Args:
        x (nn.Variable): Input Variable.
    """
    shape = (x.shape[0], 1, 1)
    scale = 2 * F.randint(0, 2, shape=shape) - 1
    return x * scale
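
A short usage sketch for random_flip above, assuming a 3-D input so the (B, 1, 1) scale broadcasts over the trailing axes:

import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable.from_numpy_array(np.ones((4, 3, 3), dtype=np.float32))
y = random_flip(x)      # each sample is multiplied by +1 or -1
y.forward()
print(y.d[:, 0, 0])     # one +1.0 or -1.0 per sample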
Example 4
def test_randint_forward(seed, ctx, func_name, low, high, shape):
    with nn.context_scope(ctx):
        o = F.randint(low, high, shape, seed=seed)
    assert o.shape == tuple(shape)
    assert o.parent.name == func_name
    o.forward()
    # NOTE: The following should be < high,
    # but use <= high because std::uniform_random contains a bug.
    assert np.all(o.d <= high)
    assert np.all(o.d >= low)
Example 5
    def build_train_graph(self,
                          x,
                          t=None,
                          dropout=0,
                          noise=None,
                          loss_scaling=None):
        B, C, H, W = x.shape
        if self.randflip:
            x = F.random_flip(x)
            assert x.shape == (B, C, H, W)

        if t is None:
            t = F.randint(low=0,
                          high=self.diffusion.num_timesteps,
                          shape=(B, ))
            # F.randint can return high itself with very low probability; clip as a workaround.
            t = F.clip_by_value(t,
                                min=0,
                                max=self.diffusion.num_timesteps - 0.5)

        loss_dict = self.diffusion.train_loss(model=partial(self._denoise,
                                                            dropout=dropout),
                                              x_start=x,
                                              t=t,
                                              noise=noise)
        assert isinstance(loss_dict, AttrDict)

        # setup training loss
        loss_dict.batched_loss = loss_dict.mse
        if is_learn_sigma(self.model_var_type):
            assert "vlb" in loss_dict
            loss_dict.batched_loss += loss_dict.vlb * 1e-3
            # todo: implement loss aware sampler

        if loss_scaling is not None and loss_scaling > 1:
            loss_dict.batched_loss *= loss_scaling

        # setup flat training loss
        loss_dict.loss = F.mean(loss_dict.batched_loss)
        assert loss_dict.batched_loss.shape == t.shape == (B, )

        # Keep interval values to compute loss for each quantile
        t.persistent = True
        for v in loss_dict.values():
            v.persistent = True

        return loss_dict, t
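
The timestep sampling inside build_train_graph can be exercised on its own. A minimal sketch with assumed values (num_timesteps and batch_size are placeholders, not the repository's settings):

import nnabla.functions as F

num_timesteps = 1000   # assumed diffusion length
batch_size = 8
t = F.randint(low=0, high=num_timesteps, shape=(batch_size,))
# Same workaround as above: F.randint may rarely return high itself, so clip it back.
t = F.clip_by_value(t, min=0, max=num_timesteps - 0.5)
t.forward()
print(t.d)  # per-sample timesteps, all strictly below num_timesteps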
Example 6
def test_randint_forward(seed, ctx, func_name, low, high, shape):
    with nn.context_scope(ctx):
        o = F.randint(low, high, shape, seed=seed)
    assert o.shape == tuple(shape)
    assert o.parent.name == func_name
    o.forward()
    # NOTE: The following should be < high,
    # but use <= high because std::uniform_random contains a bug.
    assert np.all(o.d <= high)
    assert np.all(o.d >= low)

    # Checking recomputation
    func_args = [low, high, shape, seed]
    recomputation_test(rng=None,
                       func=F.randint,
                       vinputs=[],
                       func_args=func_args,
                       func_kwargs={},
                       ctx=ctx)
Example 7
    def __call__(self,
                 batch_size,
                 style_noises,
                 truncation_psi=1.0,
                 return_latent=False,
                 mixing_layer_index=None,
                 dlatent_avg_beta=0.995):

        with nn.parameter_scope(self.global_scope):
            # normalize noise inputs
            for i in range(len(style_noises)):
                style_noises[i] = F.div2(
                    style_noises[i],
                    F.pow_scalar(F.add_scalar(F.mean(style_noises[i]**2.,
                                                     axis=1,
                                                     keepdims=True),
                                              1e-8,
                                              inplace=False),
                                 0.5,
                                 inplace=False))

            # get latent code
            w = [
                mapping_network(style_noises[0],
                                outmaps=self.mapping_network_dim,
                                num_layers=self.mapping_network_num_layers)
            ]
            w += [
                mapping_network(style_noises[1],
                                outmaps=self.mapping_network_dim,
                                num_layers=self.mapping_network_num_layers)
            ]

            dlatent_avg = nn.parameter.get_parameter_or_create(
                name="dlatent_avg", shape=(1, 512))

            # Moving average update of dlatent_avg
            batch_avg = F.mean((w[0] + w[1]) * 0.5, axis=0, keepdims=True)
            update_op = F.assign(
                dlatent_avg, lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            update_op.name = 'dlatent_avg_update'
            dlatent_avg = F.identity(dlatent_avg) + 0 * update_op

            # truncation trick
            w = [lerp(dlatent_avg, _, truncation_psi) for _ in w]

            # generate output from generator
            constant_bc = nn.parameter.get_parameter_or_create(
                name="G_synthesis/4x4/Const/const",
                shape=(1, 512, 4, 4),
                initializer=np.random.randn(1, 512, 4, 4).astype(np.float32))
            constant_bc = F.broadcast(constant_bc,
                                      (batch_size, ) + constant_bc.shape[1:])

            if mixing_layer_index is None:
                mixing_layer_index_var = F.randint(1,
                                                   len(self.resolutions) * 2,
                                                   (1, ))
            else:
                mixing_layer_index_var = F.constant(val=mixing_layer_index,
                                                    shape=(1, ))
            mixing_switch_var = F.clip_by_value(
                F.arange(0,
                         len(self.resolutions) * 2) - mixing_layer_index_var,
                0, 1)
            mixing_switch_var_re = F.reshape(
                mixing_switch_var, (1, mixing_switch_var.shape[0], 1),
                inplace=False)
            w0 = F.reshape(w[0], (batch_size, 1, w[0].shape[1]), inplace=False)
            w1 = F.reshape(w[1], (batch_size, 1, w[0].shape[1]), inplace=False)
            w_mixed = w0 * mixing_switch_var_re + \
                w1 * (1 - mixing_switch_var_re)

            rgb_output = self.synthesis(w_mixed, constant_bc)

            if return_latent:
                return rgb_output, w_mixed
            else:
                return rgb_output
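
The style-mixing switch near the end is just a 0/1 step mask over the style layers. A small numpy sketch, assuming 8 resolutions (16 style layers) and a mixing index of 5:

import numpy as np

num_layers = 16        # len(self.resolutions) * 2 in the code above
mixing_index = 5       # the value F.randint would draw
mask = np.clip(np.arange(num_layers) - mixing_index, 0, 1)
print(mask)            # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]
# w_mixed = w0 * mask + w1 * (1 - mask): layers up to and including the mixing
# index take w[1], the later layers take w[0].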