Example #1
    def forward(
        self,
        encodings: torch.Tensor,
        quantized: torch.Tensor,
        cond: Tuple[torch.Tensor, ...],
        decode_step=None,
        decode_idx=None,
    ):
        if self.training:
            assert decode_step is None and decode_idx is None  # FIXME: not sure

        return_dict = dict()
        """ Compute generative logits """

        return_dict.update(
            self._core(embeddings=quantized,
                       cond=cond,
                       decode_step=decode_step,
                       decode_idx=decode_idx))
        """ Compute generative loss """

        # move the vocab dim to position 1 (channel-first), the layout the generative loss expects
        gen_loss = self.gen_loss(shift_dim(return_dict['gen_logits'], -1, 1),
                                 encodings)
        return_dict.update(loss=gen_loss)

        return return_dict
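
All of the examples in this section call a `shift_dim` helper that is not reproduced here. Judging from the shape comments (e.g. `(b, d, t, h, w, l) -> (b, t, h, w, l, d)` for `shift_dim(x, 1, -1)` in Example #3), it moves a single dimension to a new position while keeping the relative order of the others. A minimal sketch under that assumption; the original implementation may differ (for instance it may expose a `make_contiguous` flag):

    import torch

    def shift_dim(x: torch.Tensor, src_dim: int = -1, dest_dim: int = 1) -> torch.Tensor:
        # Move dimension `src_dim` to position `dest_dim`, keeping the relative
        # order of all other dimensions (a single-axis permute).
        n_dims = x.dim()
        src_dim, dest_dim = src_dim % n_dims, dest_dim % n_dims
        dims = list(range(n_dims))
        dims.pop(src_dim)
        dims.insert(dest_dim, src_dim)
        return x.permute(*dims).contiguous()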
Example #2
    def forward(self, x):
        x = self.stem(x)
        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.group4(x)
        if self.pool:
            # average over all spatio-temporal dims -> (b, resnet_dim)
            dim = [2 + i for i in range(self.n_dim)]
            x = torch.mean(x, dim=dim).view(x.shape[0], -1)
        else:
            # (b, t, h, w, resnet_dim)
            x = shift_dim(x, 1, -1)

        return x
Example #3
    def inputs_fn(batch):
        with torch.no_grad():
            videos = batch['video'].to(device, non_blocking=True)  # (b, c, t, h, w)

            cond = []
            if cond_hp['n_cond_frames'] > 0:
                cond_frames = videos[:, :, :cond_hp['n_cond_frames']]
                cond.append(cond_frames)
            if cond_hp['class_cond']:
                cond.append(batch['label'].to(device, non_blocking=True))

            quantized, encodings = vqvae.encode(x=videos, no_flatten=True)

            # latent_shape = (t, h, w, l)
            quantized = shift_dim(quantized, 1, -1)  # (b, d, t, h, w, l) -> (b, t, h, w, l, d)  # channel first -> last
            encodings = encodings.long()

            cond = tuple(cond)
            return dict(encodings=encodings, quantized=quantized, cond=cond,
                        decode_step=None, decode_idx=None)
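
`inputs_fn` builds exactly the keyword arguments that the prior's forward in Example #1 expects, so a training step presumably reduces to unpacking its result. A rough sketch under that assumption (the names `prior`, `optimizer`, and `train_loader` are hypothetical here):

    for batch in train_loader:
        inputs = inputs_fn(batch)          # encodings, quantized, cond, decode_step=None, decode_idx=None
        return_dict = prior(**inputs)      # Example #1's forward: generative logits + loss
        optimizer.zero_grad()
        return_dict['loss'].backward()
        optimizer.step()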
Example #4
    def forward(self, x):
        """
        :param x: torch.Tensor with shape (b, c, t, h, w)
        """
        return_dict = OrderedDict()
        z = self.pre_vq_conv1(self.encoder(x=x))

        vq_output = self.codebook(z, no_flatten=True)
        dec_inp = vq_output['quantized']
        dec_inp = shift_dim(dec_inp, -1, 1).flatten(1, 2)  # (b, d, t', h', w', l) -> (b, l, d, t', h', w') -> (b, l*d, t', h', w')
        x_recon = self.decoder(x=dec_inp)

        commitment_loss = vq_output['commitment_loss']
        # 0.06 is a fixed normalizing constant for the reconstruction term
        # (presumably approximating the data variance), keeping it on a
        # comparable scale to the commitment loss
        recon_loss = F.mse_loss(x_recon, x) / 0.06
        loss = commitment_loss + recon_loss

        return_dict.update(loss=loss,
                           commitment=commitment_loss,
                           recon=recon_loss,
                           perplexity=vq_output['perplexity'])

        return return_dict
Example #5
    def forward(self, q, k, v, decode_step, decode_idx):
        """ Compute multi-head attention
        Args
            q, k, v: a [b, d1, ..., dn, c] tensor or
                     a [b, 1, ..., 1, c] tensor if decode_step is not None
            decode_step: an integer representing the current sampling index in AR ordering
            decode_idx: a tuple representing the current tensor index being sampled

        Returns
            The output after performing attention and any auxiliary losses if relevant
            (aux_loss != 0 only for routing attention)
        """

        # compute k, q, v
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        q = view_range(self.w_qs(q), -1, None, (n_head, d_k))
        k = view_range(self.w_ks(k), -1, None, (n_head, d_k))
        v = view_range(self.w_vs(v), -1, None, (n_head, d_v))

        # b x n_head x seq_len x d
        # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d)
        q = shift_dim(q, -2, 1)
        k = shift_dim(k, -2, 1)
        v = shift_dim(v, -2, 1)

        # axial transformer does not use this caching
        if decode_step is not None:
            # create cache if first iter of sampling
            if self.cache is None:
                if self.causal:
                    k_shape = (
                        q.shape[0],
                        n_head,
                    ) + self.shape + (self.d_k, )
                    v_shape = (
                        q.shape[0],
                        n_head,
                    ) + self.shape + (self.d_v, )
                    self.cache = dict(k=torch.zeros(k_shape,
                                                    dtype=k.dtype,
                                                    device=q.device),
                                      v=torch.zeros(v_shape,
                                                    dtype=v.dtype,
                                                    device=q.device))
                else:
                    # in the non-causal case only need to cache once
                    self.cache = dict(k=k.clone(), v=v.clone())
            if self.causal:
                idx = (slice(None, None), slice(None, None)) + \
                    tuple([slice(i, i + 1) for i in decode_idx])
                self.cache['k'][idx] = k
                self.cache['v'][idx] = v
            k, v = self.cache['k'], self.cache['v']

        a = self.attn(q, k, v, decode_step, decode_idx)

        # (b, n_head, *d_shape, d) -> (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d)
        a = shift_dim(a, 1, -2).flatten(start_dim=-2)
        a = self.fc(a)  # (b x seq_len x embd_dim)

        return a
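
Example #5 additionally relies on a `view_range` helper that is not shown. From the calls `view_range(self.w_qs(q), -1, None, (n_head, d_k))` it appears to reshape a span of dimensions into a new shape; a plausible minimal sketch (the original may differ):

    def view_range(x: torch.Tensor, i: int, j, shape) -> torch.Tensor:
        # Reshape dimensions i..j (exclusive; j=None means "to the end") of x
        # into `shape`, leaving all other dimensions untouched.
        n_dims = x.dim()
        i = i if i >= 0 else n_dims + i
        j = n_dims if j is None else (j if j >= 0 else n_dims + j)
        return x.view(*x.shape[:i], *shape, *x.shape[j:])

With this, `view_range(self.w_qs(q), -1, None, (n_head, d_k))` splits the last projection dimension of size `n_head * d_k` into separate `(n_head, d_k)` dimensions.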
Example #6
    def sample(self, n, codebook, cond, device, temperature, no_flatten,
               is_root):
        if is_root and os.environ.get('VERBOSE') == '1':
            print(
                f"Need {n} samples, MAX_SAMPLES_PER_BATCH = {MAX_SAMPLES_PER_BATCH}"
            )
        samples = torch.zeros((n, ) + self.shape).long().to(device)
        assert all(
            n == c.shape[0]
            for c in cond), f"cond shapes {[c.shape for c in cond]}, n = {n}"

        for i in range(0, samples.shape[0], MAX_SAMPLES_PER_BATCH):
            if is_root:
                pbar = tqdm(total=np.prod(self.shape))

            samples_subset = samples[i:i + MAX_SAMPLES_PER_BATCH]
            cond_subset = tuple([c[i:i + MAX_SAMPLES_PER_BATCH] for c in cond])

            with torch.no_grad(), self.sample_mode():
                prev_idx = None
                for j, idx in enumerate(self.sample_order()):
                    # idx must be a tuple, and not a list
                    # pytorch tensor indexing is different when using list vs tuple
                    # tuple is indexing, list is gather
                    batch_idx_slice = (slice(None, None), ) + tuple(
                        [slice(i, i + 1) for i in idx])
                    batch_idx = (slice(None, None), ) + idx

                    quantized = codebook.dictionary_lookup(samples_subset,
                                                           no_flatten=True)
                    quantized = shift_dim(quantized, 1, -1)

                    if prev_idx is None:
                        # first step: the content of s_inp does not matter, only its shape
                        s_inp = samples_subset[batch_idx_slice]
                        q_inp = torch.zeros_like(quantized[batch_idx_slice])
                        s_inp, q_inp = s_inp.to(device), q_inp.to(device)
                    else:
                        s_inp, q_inp = samples_subset[prev_idx], quantized[
                            prev_idx]

                    logits = self(quantized=q_inp,
                                  encodings=s_inp,
                                  cond=cond_subset,
                                  decode_step=j,
                                  decode_idx=idx)['gen_logits']
                    probs = F.softmax(logits / temperature, dim=-1)
                    # squeeze out the singleton spatio-temporal dims but keep the
                    # batch dim, leaving a (b, vocab) matrix for multinomial sampling
                    if probs.shape[0] == 1:
                        probs = probs.squeeze().unsqueeze(0)
                    else:
                        probs = probs.squeeze()
                    samples_subset[batch_idx] = torch.multinomial(
                        probs, 1).squeeze(-1)

                    prev_idx = batch_idx_slice

                    if is_root:
                        pbar.update(1)

                    if os.environ.get('DEBUG') == '1':
                        break

                if is_root:
                    pbar.close()

                samples[i:i + MAX_SAMPLES_PER_BATCH] = samples_subset

                if os.environ.get('DEBUG') == '1':
                    break

        assert samples.shape[0] == n

        encodings = samples
        quantized = codebook.dictionary_lookup(encodings,
                                               no_flatten=no_flatten)

        return quantized, encodings
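
A call to `sample` then follows the signature above; the argument values below are purely illustrative, with `vqvae`, `prior`, `cond`, and `device` obtained as in Examples #3 and #7, and `cond` tensors sized to match `n`:

    quantized, encodings = prior.sample(
        n=16,
        codebook=vqvae.codebook,
        cond=cond,            # tuple of conditioning tensors, as built in Example #3
        device=device,
        temperature=1.0,
        no_flatten=False,
        is_root=True,
    )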
Example #7
def main_worker(rank, size, args_in):
    global args
    args = args_in
    is_root = rank == 0

    dist.init_process_group(backend='nccl',
                            init_method=f'tcp://localhost:{args.port}',
                            world_size=size,
                            rank=rank)

    assert args.n_samples % size == 0, f'n_samples {args.n_samples} not divisible by size {size}'

    seed = args.seed + rank
    seed_all(seed)
    device = config_device()

    prior_ckpt = torch.load(get_ckpt(args.prior_ckpt), map_location=device)
    vqvae_ckpt = torch.load(get_ckpt(prior_ckpt['vqvae_ckpt']),
                            map_location=device)
    """ Load datasets """
    dset_configs = prior_ckpt['dset_configs']
    cond_hp = prior_ckpt['cond_hp']

    train_loader, test_loader, dset = get_distributed_loaders(
        dset_configs=dset_configs, batch_size=args.n_samples, seed=seed)

    loader = test_loader
    # shuffle the dataset according to some fixed seed
    loader.sampler.set_epoch(seed)
    batch = next(
        iter(loader)
    )  # get batch as early as possible for fixed seed sampling examples

    vqvae, vq_hp = load_model(ckpt=vqvae_ckpt,
                              device=device,
                              freeze_model=True,
                              cond_types=())

    def load_layer_prior(ckpt):
        # must use the same self_gen_types for vae and all prior layers
        cond_types, cond_hp = config_cond_types(cond_hp=ckpt['cond_hp'],
                                                dset=dset)
        # freeze all previous priors, not the current one
        prior, hp = load_model(ckpt=ckpt,
                               device=device,
                               freeze_model=True,
                               cond_types=cond_types)
        codebook = vqvae.codebook
        return prior, hp, codebook

    latent_shape = vqvae.latent_shape
    quantized_shape = vqvae.quantized_shape

    if is_root:
        print('latent shapes', latent_shape)
        print('quantized shape', quantized_shape)
        print('total latents', np.prod(latent_shape))

    prior, prior_hp, codebook = load_layer_prior(prior_ckpt)
    if is_root:
        print(
            f"Loaded vqvae at iteration {vqvae_ckpt['iteration']}, loss = {vqvae_ckpt['best_loss']}"
        )
        print(
            f"Loaded GPT at iteration {prior_ckpt['iteration']}, loss {prior_ckpt['best_loss']}"
        )
    """ Generate samples """
    sample_fn = functools.partial(sample,
                                  cond_hp=cond_hp,
                                  vae=vqvae,
                                  prior=prior,
                                  codebook=codebook,
                                  device=device,
                                  temperature=args.temperature,
                                  rank=rank,
                                  size=size)

    gathered_samples, gathered_cond = sample_fn(n_samples=args.n_samples,
                                                batch=batch,
                                                gather=True)

    if is_root:
        os.makedirs(args.output_dir, exist_ok=True)
        # (n, c, t, h, w) -> (n, t, c, h, w)
        samples = shift_dim(gathered_samples, 2, 1)
        T = samples.shape[1]
        save_image(samples.flatten(end_dim=1),
                   osp.join(args.output_dir, "samples.png"),
                   nrow=T)

        # (n, t, c, h, w) -> (n, t, h, w, c)
        samples = shift_dim(samples, 2, -1)
        samples = (samples.cpu().numpy() * 255).astype('uint8')

        for i in range(min(args.n_samples, MAX_N_SAMPLES_TO_VIDEO)):
            skvideo.io.vwrite(osp.join(args.output_dir, f'samples_{i}.mp4'),
                              samples[i],
                              inputdict={'-r': '5'})
        np.save(osp.join(args.output_dir, 'samples.npy'), samples)

        print('wrote videos to:', args.output_dir)
Example #8
    def forward(self, x):
        x = shift_dim(x, 1, -1)  # channels last, e.g. (b, c, t, h, w) -> (b, t, h, w, c)
        x = self.norm(x)         # norm acts over the trailing channel dimension
        x = shift_dim(x, -1, 1)  # restore the channel-first layout
        return x
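
Example #8 is a thin wrapper that lets a normalization layer expecting channels last (such as nn.LayerNorm) operate on channel-first video tensors. A self-contained sketch of such a module, reusing the `shift_dim` sketch above; the class name and constructor are assumptions, since only forward appears in the snippet:

    import torch.nn as nn

    class ChannelLastNorm(nn.Module):
        # Hypothetical wrapper: applies a channels-last norm to a (b, c, ...) tensor.
        def __init__(self, n_channels: int):
            super().__init__()
            self.norm = nn.LayerNorm(n_channels)

        def forward(self, x):
            x = shift_dim(x, 1, -1)   # (b, c, t, h, w) -> (b, t, h, w, c)
            x = self.norm(x)          # normalize over the channel dimension
            x = shift_dim(x, -1, 1)   # back to (b, c, t, h, w)
            return x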