def sample_recurrent(self, n_samples, x_cond=None, y_cond=None, encoder_kv=None, fp16=False,
                     temp=1.0, top_k=0, top_p=0.0, get_preds=False, sample_tokens=None):
    assert self.training == False
    memory = None
    if sample_tokens is None:
        sample_tokens = self.input_dims
    N, D = n_samples, self.input_dims
    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None
    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), \
            f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

    with t.no_grad():
        xs, x = [], None
        if get_preds:
            preds = []
        for sample_t in get_range(range(0, sample_tokens)):
            x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
            # self.transformer.check_cache(n_samples, sample_t, fp16)
            x, memory = self.transformer(x[:, -1, :], memory)  # Transformer
            x = t.unsqueeze(x, 1)
            if self.add_cond_after_transformer:
                x = x + cond
            assert x.shape == (n_samples, 1, self.width)
            x = self.x_out(x)  # Predictions
            if get_preds:
                preds.append(x.clone())
            # Adjust logits
            x = x / temp
            x = filter_logits(x, top_k=top_k, top_p=top_p)
            x = t.distributions.Categorical(logits=x).sample()  # Sample and replace x
            assert x.shape == (n_samples, 1)
            xs.append(x.clone())

        del x
        # self.transformer.del_cache()

        x = t.cat(xs, dim=1)
        if get_preds:
            preds = t.cat(preds, dim=1)
        x = self.postprocess(x, sample_tokens)
    if get_preds:
        return x, preds
    else:
        return x, memory
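# --- Usage sketch (hypothetical; not part of the original file) ---
# Shows how sample_recurrent might be called. `prior`, `x_cond`, and `y_cond`
# are assumed names: `prior` would be an instance of this class in eval mode,
# and the conditioning tensors would come from its upstream conditioners.
#
#     prior.eval()
#     tokens, memory = prior.sample_recurrent(
#         n_samples=4, x_cond=x_cond, y_cond=y_cond,
#         temp=0.99, top_p=0.95, sample_tokens=2048,
#     )
#
# The returned `memory` is the recurrent transformer state after the last
# generated token; the recurrent primed_sample below takes such a state as input.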
def primed_sample(self, n_samples, x, memory, x_cond=None, y_cond=None, encoder_kv=None, fp16=False,
                  temp=1.0, top_k=0, top_p=0.0, get_preds=False, chunk_size=None, sample_tokens=None):
    assert self.training == False
    if sample_tokens is None:
        sample_tokens = self.input_dims

    # Preprocess.
    with t.no_grad():
        x = self.preprocess(x)

    assert isinstance(x, t.cuda.LongTensor)
    assert (0 <= x).all() and (x < self.bins).all()
    assert x.shape[0] == n_samples
    xs = t.split(x, 1, dim=1)
    xs = list(xs)
    assert len(xs) < sample_tokens

    N, D = n_samples, self.input_dims
    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None
    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), \
            f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

    with t.no_grad():
        if get_preds:
            preds = []

        x = xs[-1]
        assert x.shape == (n_samples, 1)
        empty_cache()
        for sample_t in get_range(range(len(xs), sample_tokens)):
            x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
            # self.transformer.check_cache(n_samples, sample_t, fp16)
            x, memory = self.transformer(x[:, -1, :], memory)  # Transformer
            x = t.unsqueeze(x, 1)
            if self.add_cond_after_transformer:
                x = x + cond
            assert x.shape == (n_samples, 1, self.width)
            x = self.x_out(x)  # Predictions
            if get_preds:
                preds.append(x)
            # Adjust logits
            x = x / temp
            x = filter_logits(x, top_k=top_k, top_p=top_p)
            x = t.distributions.Categorical(logits=x).sample()  # Sample and replace x
            assert x.shape == (n_samples, 1)
            xs.append(x.clone())

        del x
        # self.transformer.del_cache()

        x = t.cat(xs, dim=1)
        if get_preds:
            preds = t.cat(preds, dim=1)
        x = self.postprocess(x, sample_tokens)
    if get_preds:
        return x, preds
    else:
        return x, memory
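# --- Usage sketch (hypothetical; not part of the original file) ---
# Recurrent variant of primed_sample: unlike the cached variant below, the prompt
# tokens are never re-run through the transformer here, so `memory` is presumably
# expected to already hold the recurrent state for the prompt (e.g. the state
# returned by sample_recurrent above, or by an earlier call to this method).
#
#     tokens, memory = prior.primed_sample(
#         n_samples=4, x=prompt_tokens, memory=memory,
#         x_cond=x_cond, y_cond=y_cond, temp=0.99, top_p=0.95,
#     )
#
# `prompt_tokens` is an assumed (n_samples, prompt_length) LongTensor of codes.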
def primed_sample(self, n_samples, x, x_cond=None, y_cond=None, encoder_kv=None, fp16=False,
                  temp=1.0, top_k=0, top_p=0.0, get_preds=False, chunk_size=None, sample_tokens=None):
    assert self.training == False
    if sample_tokens is None:
        sample_tokens = self.input_dims

    # Preprocess.
    with t.no_grad():
        x = self.preprocess(x)

    assert isinstance(x, t.cuda.LongTensor)
    assert (0 <= x).all() and (x < self.bins).all()
    assert x.shape[0] == n_samples
    xs = t.split(x, 1, dim=1)
    xs = list(xs)
    assert len(xs) < sample_tokens

    N, D = n_samples, self.input_dims
    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None
    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), \
            f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

    with t.no_grad():
        if get_preds:
            preds = []

        # Fill up key/value cache for past context by running a forward pass.
        # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage.
        if chunk_size is None:
            chunk_size = len(xs)
        # assert len(xs) % chunk_size == 0, f'expected {len(xs)} to be divisible by {chunk_size}'
        chunk_sizes = split_chunks(len(xs), chunk_size)

        x_primes = []
        start = 0
        x = None
        for current_chunk_size in get_range(chunk_sizes):
            xs_prime, conds_prime = [], []
            for sample_t in range(start, start + current_chunk_size):
                x_prime, cond_prime = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                x = xs[sample_t]
                xs_prime.append(x_prime)
                conds_prime.append(cond_prime)
            start = start + current_chunk_size

            x_prime, cond_prime = t.cat(xs_prime, dim=1), t.cat(conds_prime, dim=1)
            assert x_prime.shape == (n_samples, current_chunk_size, self.width)
            assert cond_prime.shape == (n_samples, current_chunk_size, self.width)
            del xs_prime
            del conds_prime
            if not get_preds:
                del cond_prime
            x_prime = self.transformer(x_prime, encoder_kv=encoder_kv, sample=True, fp16=fp16)

            if get_preds:
                if self.add_cond_after_transformer:
                    x_prime = x_prime + cond_prime
                assert x_prime.shape == (n_samples, current_chunk_size, self.width)
                del cond_prime
                x_primes.append(x_prime)
            else:
                del x_prime

        if get_preds:
            x_prime = t.cat(x_primes, dim=1)
            assert x_prime.shape == (n_samples, len(xs), self.width)
            x_prime = self.x_out(x_prime)  # Predictions
            preds.append(x_prime)

        empty_cache()
        self.transformer.check_cache(n_samples, len(xs), fp16)

        x = xs[-1]
        assert x.shape == (n_samples, 1)
        empty_cache()
        for sample_t in get_range(range(len(xs), sample_tokens)):
            x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
            self.transformer.check_cache(n_samples, sample_t, fp16)
            x = self.transformer(x, encoder_kv=encoder_kv, sample=True, fp16=fp16)  # Transformer
            if self.add_cond_after_transformer:
                x = x + cond
            assert x.shape == (n_samples, 1, self.width)
            x = self.x_out(x)  # Predictions
            if get_preds:
                preds.append(x)
            # Adjust logits
            x = x / temp
            x = filter_logits(x, top_k=top_k, top_p=top_p)
            x = t.distributions.Categorical(logits=x).sample()  # Sample and replace x
            assert x.shape == (n_samples, 1)
            xs.append(x.clone())

        del x
        self.transformer.del_cache()

        x = t.cat(xs, dim=1)
        if get_preds:
            preds = t.cat(preds, dim=1)
        x = self.postprocess(x, sample_tokens)
    if get_preds:
        return x, preds
    else:
        return x
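# --- Usage sketch (hypothetical; not part of the original file) ---
# This second primed_sample appears to be the original cached-attention variant:
# the prompt is pushed through the transformer in chunks to fill the key/value
# cache before sampling continues token by token, and it returns only the tokens
# (no recurrent memory). Assumed names: `prior`, `prompt_tokens`, `x_cond`, `y_cond`.
#
#     tokens = prior.primed_sample(
#         n_samples=4, x=prompt_tokens, x_cond=x_cond, y_cond=y_cond,
#         fp16=True, temp=0.99, top_p=0.95, chunk_size=32,
#     )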