def forward(
    self,
    encodings: torch.Tensor,
    quantized: torch.Tensor,
    cond: Tuple[torch.Tensor, ...],
    decode_step=None,
    decode_idx=None,
):
    if self.training:
        # decode_step / decode_idx drive the sampling-time attention cache;
        # they should never be set during training
        assert decode_step is None and decode_idx is None
    return_dict = dict()

    """ Compute generative logits """
    return_dict.update(
        self._core(embeddings=quantized, cond=cond,
                   decode_step=decode_step, decode_idx=decode_idx))

    """ Compute generative loss """
    # gen_logits are channel-last; the loss expects (b, vocab, ...), so shift
    gen_loss = self.gen_loss(shift_dim(return_dict['gen_logits'], -1, 1),
                             encodings)
    return_dict.update(loss=gen_loss)
    return return_dict
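# `shift_dim` is used throughout these snippets but not defined in this
# excerpt. Below is a minimal sketch of the permutation helper it appears to
# be (VideoGPT-style): move one dimension to a new position while preserving
# the relative order of all the others. The real helper may differ in detail.
import torch


def shift_dim(x: torch.Tensor, src_dim: int = -1, dest_dim: int = -1) -> torch.Tensor:
    """Move dimension `src_dim` of `x` to position `dest_dim`."""
    n = x.dim()
    src_dim, dest_dim = src_dim % n, dest_dim % n
    dims = list(range(n))
    dims.pop(src_dim)
    dims.insert(dest_dim, src_dim)
    return x.permute(*dims).contiguous()


# e.g. channel-first -> channel-last, as in the encode path further below:
x = torch.randn(2, 64, 4, 8, 8)              # (b, d, t, h, w)
assert shift_dim(x, 1, -1).shape == (2, 4, 8, 8, 64)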
def forward(self, x):
    x = self.stem(x)
    x = self.group1(x)
    x = self.group2(x)
    x = self.group3(x)
    x = self.group4(x)
    if self.pool:
        # (b, resnet_dim)
        dim = [2 + i for i in range(self.n_dim)]
        x = torch.mean(x, dim=dim).view(x.shape[0], -1)
    else:
        # (b, t, h, w, resnet_dim)
        x = shift_dim(x, 1, -1)
    return x
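# A quick shape check of the pooling branch above: for 3-D video features
# (n_dim == 3), `dim` is [2, 3, 4], so the mean collapses every
# spatiotemporal axis and leaves one vector per example. Sizes below are
# made up for illustration.
feats = torch.randn(2, 512, 4, 8, 8)          # (b, resnet_dim, t, h, w)
pooled = torch.mean(feats, dim=[2, 3, 4]).view(feats.shape[0], -1)
assert pooled.shape == (2, 512)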
def inputs_fn(batch):
    with torch.no_grad():
        videos = batch['video'].to(device, non_blocking=True)  # (b, c, t, h, w)
        cond = []
        if cond_hp['n_cond_frames'] > 0:
            cond_frames = videos[:, :, :cond_hp['n_cond_frames']]
            cond.append(cond_frames)
        if cond_hp['class_cond']:
            cond.append(batch['label'].to(device, non_blocking=True))

        quantized, encodings = vqvae.encode(x=videos, no_flatten=True)
        # latent_shape = (t, h, w, l)
        # channel-first -> channel-last: (b, d, t, h, w, l) -> (b, t, h, w, l, d)
        quantized = shift_dim(quantized, 1, -1)
        encodings = encodings.long()
        cond = tuple(cond)

    return dict(encodings=encodings, quantized=quantized, cond=cond,
                decode_step=None, decode_idx=None)
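# The dict above matches the prior's forward signature key-for-key, so a
# training step can simply splat it. A sketch; `prior` and `batch` are
# assumed names, not defined in this excerpt:
#
#   return_dict = prior(**inputs_fn(batch))
#   return_dict['loss'].backward()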
def forward(self, x):
    """
    :param x: torch.Tensor with shape (b, c, t, h, w)
    """
    return_dict = OrderedDict()

    z = self.pre_vq_conv1(self.encoder(x=x))
    vq_output = self.codebook(z, no_flatten=True)

    dec_inp = vq_output['quantized']
    # -> (b, l, d, t', h', w') -> (b, l*d, t', h', w')
    dec_inp = shift_dim(dec_inp, -1, 1).flatten(1, 2)
    x_recon = self.decoder(x=dec_inp)

    commitment_loss = vq_output['commitment_loss']
    # the constant rescales the reconstruction MSE relative to the commitment term
    recon_loss = F.mse_loss(x_recon, x) / 0.06
    loss = commitment_loss + recon_loss

    return_dict.update(loss=loss,
                       commitment=commitment_loss,
                       recon=recon_loss,
                       perplexity=vq_output['perplexity'])
    return return_dict
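# A standalone shape check for the decoder-input reshape above, with made-up
# sizes. It assumes `vq_output['quantized']` is channel-first with the
# codebook axis last, (b, d, t', h', w', l); moving l to position 1 and
# flattening merges the codebook and embedding axes.
q = torch.randn(2, 64, 4, 8, 8, 1)                  # (b, d, t', h', w', l)
dec_inp = shift_dim(q, -1, 1).flatten(1, 2)         # (b, l, d, ...) -> (b, l*d, t', h', w')
assert dec_inp.shape == (2, 64, 4, 8, 8)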
def forward(self, q, k, v, decode_step, decode_idx):
    """ Compute multi-head attention
    Args
        q, k, v: a [b, d1, ..., dn, c] tensor, or a [b, 1, ..., 1, c]
            tensor if decode_step is not None
        decode_step: an integer representing the current sampling index
            in the autoregressive (AR) ordering
        decode_idx: a tuple representing the current tensor index being
            sampled
    Returns
        The output after performing attention and any auxiliary losses
        if relevant (aux_loss != 0 only for routing attention)
    """
    # project q, k, v and split the channel axis into heads
    d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
    q = view_range(self.w_qs(q), -1, None, (n_head, d_k))
    k = view_range(self.w_ks(k), -1, None, (n_head, d_k))
    v = view_range(self.w_vs(v), -1, None, (n_head, d_v))

    # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d)
    q = shift_dim(q, -2, 1)
    k = shift_dim(k, -2, 1)
    v = shift_dim(v, -2, 1)

    # axial transformer does not use this caching
    if decode_step is not None:
        # create the cache on the first sampling iteration
        if self.cache is None:
            if self.causal:
                k_shape = (q.shape[0], n_head,) + self.shape + (self.d_k,)
                v_shape = (q.shape[0], n_head,) + self.shape + (self.d_v,)
                self.cache = dict(
                    k=torch.zeros(k_shape, dtype=k.dtype, device=q.device),
                    v=torch.zeros(v_shape, dtype=v.dtype, device=q.device))
            else:
                # in the non-causal case we only need to cache once
                self.cache = dict(k=k.clone(), v=v.clone())

        if self.causal:
            idx = (slice(None, None), slice(None, None)) + \
                tuple([slice(d, d + 1) for d in decode_idx])
            self.cache['k'][idx] = k
            self.cache['v'][idx] = v
        k, v = self.cache['k'], self.cache['v']

    a = self.attn(q, k, v, decode_step, decode_idx)

    # (b, n_head, *d_shape, d) -> (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d)
    a = shift_dim(a, 1, -2).flatten(start_dim=-2)
    a = self.fc(a)  # (b, *d_shape, embd_dim)
    return a
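# `view_range` is also undefined in this excerpt; from its call sites above
# it reshapes one contiguous range of dims, here splitting the projected
# channel axis into (n_head, d). A sketch inferred from usage, not the
# repo's exact helper:
def view_range(x: torch.Tensor, i: int, j, shape) -> torch.Tensor:
    """Reshape the dim range [i, j) of `x` into `shape`, leaving all other
    dims untouched; j=None means "through the last dim"."""
    n = x.dim()
    i = i if i >= 0 else i + n
    j = n if j is None else (j if j >= 0 else j + n)
    return x.view(*x.shape[:i], *shape, *x.shape[j:])


# splitting channels into 8 heads of width 64, as in the q/k/v projections:
t = torch.randn(2, 16, 16, 512)               # (b, *d_shape, n_head * d_k)
assert view_range(t, -1, None, (8, 64)).shape == (2, 16, 16, 8, 64)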
def sample(self, n, codebook, cond, device, temperature, no_flatten, is_root):
    if is_root and os.environ.get('VERBOSE') == '1':
        print(f"Need {n} samples, MAX_SAMPLES_PER_BATCH = {MAX_SAMPLES_PER_BATCH}")

    samples = torch.zeros((n,) + self.shape).long().to(device)
    assert all(n == c.shape[0] for c in cond), \
        f"cond shapes {[c.shape for c in cond]}, n {n}"

    for i in range(0, samples.shape[0], MAX_SAMPLES_PER_BATCH):
        if is_root:
            pbar = tqdm(total=np.prod(self.shape))
        samples_subset = samples[i:i + MAX_SAMPLES_PER_BATCH]
        cond_subset = tuple([c[i:i + MAX_SAMPLES_PER_BATCH] for c in cond])

        with torch.no_grad(), self.sample_mode():
            prev_idx = None
            for j, idx in enumerate(self.sample_order()):
                # idx must be a tuple, not a list: pytorch treats a tuple as
                # indexing and a list as gather
                batch_idx_slice = (slice(None, None),) + \
                    tuple([slice(d, d + 1) for d in idx])
                batch_idx = (slice(None, None),) + idx

                quantized = codebook.dictionary_lookup(samples_subset, no_flatten=True)
                quantized = shift_dim(quantized, 1, -1)

                if prev_idx is None:
                    # first step: the input values don't matter, only the shape
                    s_inp = samples_subset[batch_idx_slice]
                    q_inp = torch.zeros_like(quantized[batch_idx_slice])
                    s_inp, q_inp = s_inp.to(device), q_inp.to(device)
                else:
                    s_inp, q_inp = samples_subset[prev_idx], quantized[prev_idx]

                logits = self(quantized=q_inp, encodings=s_inp, cond=cond_subset,
                              decode_step=j, decode_idx=idx)['gen_logits']
                probs = F.softmax(logits / temperature, dim=-1)
                # collapse singleton dims to (batch, vocab), keeping the
                # batch dim when batch size is 1
                if probs.shape[0] == 1:
                    probs = probs.squeeze().unsqueeze(0)
                else:
                    probs = probs.squeeze()
                samples_subset[batch_idx] = torch.multinomial(probs, 1).squeeze(-1)

                prev_idx = batch_idx_slice
                if is_root:
                    pbar.update(1)
                if os.environ.get('DEBUG') == '1':
                    break
        if is_root:
            pbar.close()
        samples[i:i + MAX_SAMPLES_PER_BATCH] = samples_subset
        if os.environ.get('DEBUG') == '1':
            break

    assert samples.shape[0] == n
    encodings = samples
    quantized = codebook.dictionary_lookup(encodings, no_flatten=no_flatten)
    return quantized, encodings
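# A hedged usage sketch of the sampler above. The names `prior`, `vqvae`,
# `labels`, and `device` are assumptions standing in for objects built
# elsewhere, so this is left as a comment rather than live code:
#
#   with torch.no_grad():
#       quantized, encodings = prior.sample(
#           n=4, codebook=vqvae.codebook, cond=(labels,), device=device,
#           temperature=1.0, no_flatten=False, is_root=True)
#   # ...then decode `quantized` with the VQ-VAE decoder (path not shown
#   # in this excerpt)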
def main_worker(rank, size, args_in):
    global args
    args = args_in
    is_root = rank == 0

    dist.init_process_group(backend='nccl',
                            init_method=f'tcp://localhost:{args.port}',
                            world_size=size, rank=rank)
    assert args.n_samples % size == 0, \
        f'n_samples {args.n_samples} not divisible by size {size}'

    seed = args.seed + rank
    seed_all(seed)
    device = config_device()

    prior_ckpt = torch.load(get_ckpt(args.prior_ckpt), map_location=device)
    vqvae_ckpt = torch.load(get_ckpt(prior_ckpt['vqvae_ckpt']), map_location=device)

    """ Load datasets """
    dset_configs = prior_ckpt['dset_configs']
    cond_hp = prior_ckpt['cond_hp']

    train_loader, test_loader, dset = get_distributed_loaders(
        dset_configs=dset_configs, batch_size=args.n_samples, seed=seed)
    loader = test_loader
    # shuffle the dataset according to a fixed seed
    loader.sampler.set_epoch(seed)
    # grab the batch as early as possible for fixed-seed sampling examples
    batch = next(iter(loader))

    vqvae, vq_hp = load_model(ckpt=vqvae_ckpt, device=device,
                              freeze_model=True, cond_types=())

    def load_layer_prior(ckpt):
        # must use the same self_gen_types for the vae and all prior layers
        cond_types, cond_hp = config_cond_types(cond_hp=ckpt['cond_hp'], dset=dset)
        # freeze all previous priors, but not the current one
        prior, hp = load_model(ckpt=ckpt, device=device,
                               freeze_model=True, cond_types=cond_types)
        codebook = vqvae.codebook
        return prior, hp, codebook

    latent_shape = vqvae.latent_shape
    quantized_shape = vqvae.quantized_shape
    if is_root:
        print('latent shapes', latent_shape)
        print('quantized shape', quantized_shape)
        print('total latents', np.prod(latent_shape))

    prior, prior_hp, codebook = load_layer_prior(prior_ckpt)
    if is_root:
        print(f"Loaded vqvae at iteration {vqvae_ckpt['iteration']}, "
              f"loss = {vqvae_ckpt['best_loss']}")
        print(f"Loaded GPT at iteration {prior_ckpt['iteration']}, "
              f"loss {prior_ckpt['best_loss']}")

    """ Generate samples """
    sample_fn = functools.partial(sample, cond_hp=cond_hp, vae=vqvae,
                                  prior=prior, codebook=codebook, device=device,
                                  temperature=args.temperature, rank=rank, size=size)
    gathered_samples, gathered_cond = sample_fn(n_samples=args.n_samples,
                                                batch=batch, gather=True)

    if is_root:
        os.makedirs(args.output_dir, exist_ok=True)

        # (n, c, t, h, w) -> (n, t, c, h, w)
        samples = shift_dim(gathered_samples, 2, 1)
        T = samples.shape[1]
        save_image(samples.flatten(end_dim=1),
                   osp.join(args.output_dir, "samples.png"), nrow=T)

        # (n, t, c, h, w) -> (n, t, h, w, c)
        samples = shift_dim(samples, 2, -1)
        samples = (samples.cpu().numpy() * 255).astype('uint8')
        for i in range(min(args.n_samples, MAX_N_SAMPLES_TO_VIDEO)):
            skvideo.io.vwrite(osp.join(args.output_dir, f'samples_{i}.mp4'),
                              samples[i], inputdict={'-r': '5'})
        np.save(osp.join(args.output_dir, 'samples.npy'), samples)
        print('wrote videos to:', args.output_dir)
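# `main_worker(rank, size, args_in)` has the (rank, *args) signature that
# torch.multiprocessing.spawn expects, so a plausible launcher is sketched
# below; `parse_args` and `args.n_gpus` are assumptions, not part of this
# excerpt.
import torch.multiprocessing as mp

if __name__ == '__main__':
    args = parse_args()   # assumed CLI helper
    size = args.n_gpus    # assumed flag: one process per GPU
    mp.spawn(main_worker, args=(size, args), nprocs=size, join=True)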
def forward(self, x):
    # the norm expects channel-last input, so shift, normalize, shift back
    x = shift_dim(x, 1, -1)
    x = self.norm(x)
    x = shift_dim(x, -1, 1)
    return x