from keras import backend as K
from keras.models import Model


class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())
        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        # To maximize ELBO, keras will minimize "loss" of -ELBO
        self.model.add_loss(-self.elbo())

    def elbo(self):
        flat_input = K.batch_flatten(self.inpt)
        # LL term is E_q(z|x) [ log p(x|z) ] and has shape (batch,)
        self.ll = self.likelihood.log_prob(flat_input)
        # KL term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch,)
        try:
            # Use analytic KL if it is available; otherwise fall back on sample KL
            self.kl = self.latent.analytic_kl()
        except TypeError:
            self.kl = self.latent.sample_kl()
        # ELBO is simply (LL - KL) and has shape (batch,)
        return self.ll - self.kl
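For reference, a minimal sketch of how this class might be driven. The encoder/decoder wiring that produces inpt, latent, reconstruction, and likelihood is assumed to live elsewhere; build_vae_graph and x_train are hypothetical placeholders, not part of the code above.

# Hypothetical wiring: build_vae_graph stands in for whatever encoder/decoder construction
# produces the four objects the VAE class expects.
inpt, latent, reconstruction, likelihood = build_vae_graph()

vae = VAE(inpt, latent, reconstruction, likelihood)
# -ELBO was registered via add_loss(), so compile() takes no separate loss function.
vae.model.compile(optimizer='adam', loss=None)
vae.model.fit(x_train, None, epochs=10, batch_size=128)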
class IWAE(object):
    def __init__(self, inpt, latents, reconstruction, likelihood, k_samples, kl_weight=1):  # noqa:E501
        self.__dict__.update(locals())
        # 'Model' is a trainable keras object
        self.model = Model(inpt, reconstruction)
        self.model.add_loss(self.elbo_loss())

    def set_samples(self, k):
        # TODO - use learning_phase() to set to 1 during testing?
        K.set_value(self.k_samples, k)

    def elbo_loss(self):
        # Repeat inputs once per latent sample to size (batch, samples, -1), where -1 stands for
        # 'all subsequent dimensions flattened'.
        input_shape = K.shape(self.inpt)
        repeated_input = K.repeat(K.batch_flatten(self.inpt), self.k_samples)
        # NLL loss term is E_q(z|x) [ -log p(x|z) ] and has shape (batch, samples)
        batch_sample_shape = K.stack([input_shape[0], self.k_samples, -1])
        batch_sample_reconstruction = K.reshape(self.reconstruction, batch_sample_shape)
        self.nll = self.likelihood.nll(repeated_input, batch_sample_reconstruction)
        # Each KL loss term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch, samples)
        kl_losses = K.sum([latent.sample_kl() for latent in self.latents], axis=0)
        # Total loss is simply the sum of KL and NLL terms and has shape (batch, samples)
        total_loss = self.kl_weight * kl_losses + self.nll
        # Final loss is a weighted sum across k samples. More precisely, the total gradient is a
        # weighted sum of sample gradients. K.stop_gradient() is used to make the weights act on
        # gradients and not provide gradients themselves (weights are not 'learned' per se).
        # Weights have shape (batch, samples).
        weights = K.stop_gradient(self._get_weights())
        # Final loss per input is a weighted sum of sample losses
        return K.sum(total_loss * weights, axis=-1)

    def _get_weights(self):
        log_likelihood = -self.nll
        log_p = K.sum([q.prior.log_prob(q.samples) for q in self.latents], axis=0)
        log_q = K.sum([q.log_prob(q.samples) for q in self.latents], axis=0)
        log_weights = log_likelihood + log_p - log_q
        log_weights -= K.logsumexp(log_weights, axis=-1, keepdims=True)
        weights_unnormalized = K.exp(log_weights)
        return weights_unnormalized / K.sum(weights_unnormalized, axis=-1, keepdims=True)
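Note that set_samples() calls K.set_value(), which only works if k_samples is a Keras backend variable rather than a plain Python int. A sketch under that assumption follows; build_iwae_graph and x_train are hypothetical placeholders for the surrounding graph construction and data.

# k_samples as a backend variable, so it can be changed after the graph is built.
k_samples = K.variable(5, dtype='int32', name='k_samples')

# build_iwae_graph is a hypothetical placeholder for the encoder/decoder wiring.
inpt, latents, reconstruction, likelihood = build_iwae_graph(k_samples)
iwae = IWAE(inpt, latents, reconstruction, likelihood, k_samples, kl_weight=1)
iwae.model.compile(optimizer='adam', loss=None)
iwae.model.fit(x_train, None, epochs=10, batch_size=128)

# Evaluate a tighter bound with more importance samples, without rebuilding the model.
iwae.set_samples(50)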
class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())
        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        # To maximize ELBO, keras will minimize "loss" of -ELBO
        self.model.add_loss(-self.elbo())

    def elbo(self):
        ...
class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood, k_samples):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())
        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        self.model.add_loss(-self.elbo())

    def set_samples(self, k):
        K.set_value(self.k_samples, k)

    def elbo(self):
        ...
class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood, k_samples):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())
        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        self.model.add_loss(-self.elbo())

    def set_samples(self, k):
        K.set_value(self.k_samples, k)

    def elbo(self):
        batch = K.shape(self.inpt)[0]
        # shape of flat_input is (batch, pixels)
        flat_input = K.batch_flatten(self.inpt)
        # shape of repeated_input is (batch, samples, pixels)
        repeated_input = K.repeat(flat_input, self.k_samples)
        # shape of flat_repeated_input is (batch * samples, pixels) to match the shape of
        # self.reconstruction
        flat_repeated_input = K.reshape(repeated_input, (batch * self.k_samples, -1))
        # LL term is E_q(z|x) [ log p(x|z) ] (Note that the mean over k_samples happens later)
        # shape of flat_ll is (batch * samples,).
        flat_ll = self.likelihood.log_prob(flat_repeated_input)
        # shape of self.ll is (batch, samples)
        self.ll = K.reshape(flat_ll, (batch, -1))
        # KL term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch,) if analytic or
        # (batch, samples) otherwise
        try:
            # Use analytic KL if it is available, which has shape (batch,)
            self.kl = self.latent.analytic_kl()
            return K.sum(self.ll, axis=-1) / K.cast(self.k_samples, 'float32') - self.kl
        except TypeError:
            # If analytic KL is not available, fall back on sample KL.
            self.kl = self.latent.sample_kl()
            # ELBO is mean-over-samples of (LL - KL) and has shape (batch,)
            return K.sum(self.ll - self.kl, axis=-1) / K.cast(self.k_samples, 'float32')
class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())
        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        # To maximize ELBO, keras will minimize "loss" of -ELBO
        self.model.add_loss(-self.elbo())

    def elbo(self):
        flat_input = K.batch_flatten(self.inpt)
        # LL term is E_q(z|x) [ log p(x|z) ] and has shape (batch,)
        self.ll = self.likelihood.log_prob(flat_input)

        # YOUR CODE HERE
        # self.kl = ...

        # ELBO is simply (LL - KL) and has shape (batch,)
        return self.ll - self.kl
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.layers import Dense, Dropout, Embedding, Input, Lambda, TimeDistributed
from keras.models import Model
# PositionEmbedding, Encode, Decode, get_mask_seq2seq, and get_mask_self are assumed to come
# from the project's own modules.


class Transformer(object):
    def __init__(self, src_dict, tar_dict=None, length_limit=70, num_layers=6,
                 model_dim=512, num_head=8, head_dim=None, inner_dim=2048,
                 dropout=0.1, use_pos_embedding=True, share_embedding=False,
                 inputs=None, outputs=None, name="Transformer"):
        if inputs is not None and outputs is not None:
            super(Transformer, self).__init__(inputs=inputs, outputs=outputs, name=name)
            return
        self.src_dict = src_dict
        self.tar_dict = tar_dict if tar_dict is not None else self.src_dict
        self.src_token_dict = {v: k for k, v in self.src_dict.items()}
        self.tar_token_dict = {v: k for k, v in self.tar_dict.items()}
        self.scr_dict_size = len(self.src_dict)
        self.tar_dict_size = len(self.tar_dict) if tar_dict is not None else self.scr_dict_size
        self.length_limit = length_limit
        self.num_layers = num_layers
        self.num_head = num_head
        self.model_dim = model_dim
        self.head_dim = head_dim if head_dim is not None else int(model_dim / num_head)
        self.inner_dim = inner_dim
        self.dropout = dropout
        self.use_pos_embedding = use_pos_embedding
        self.share_embedding = share_embedding
        self.source_embedding = None
        self.target_embedding = None
        self.position_embedding = None
        self.encoder = None
        self.decoder = None
        self.softmax = None
        self.model = None
        self.output_model = None
        self.decode_build = False
        self.encoder_model = None
        self.decoder_model = None

    def compile(self, optimizer="adam"):
        source_input = Input(shape=(None,), dtype="int32")
        target_input = Input(shape=(None,), dtype="int32")
        # Teacher forcing: the decoder reads the target shifted right by one position and is
        # trained to predict the target shifted left by one position.
        target_decode_in = Lambda(lambda x: K.slice(
            x, start=[0, 0],
            size=[K.shape(target_input)[0], K.shape(target_input)[1] - 1]))(target_input)
        target_decode_out = Lambda(lambda x: K.slice(
            x, start=[0, 1],
            size=[K.shape(target_input)[0], K.shape(target_input)[1] - 1]))(target_input)
        src_mask = Lambda(lambda x: get_mask_seq2seq(x, x))(source_input)
        tar_mask = Lambda(lambda x: self.get_self_mask(x))(target_decode_in)
        encode_mask = Lambda(lambda x: get_mask_seq2seq(x[0], x[1]))(
            [target_decode_in, source_input])
        self.source_embedding = Embedding(input_dim=self.scr_dict_size,
                                          output_dim=self.model_dim)
        if self.share_embedding:
            self.target_embedding = self.source_embedding
        else:
            self.target_embedding = Embedding(input_dim=self.tar_dict_size,
                                              output_dim=self.model_dim)
        if self.use_pos_embedding:
            self.position_embedding = PositionEmbedding(mode="sum")
        src_x = self.source_embedding(source_input)
        if self.use_pos_embedding:
            src_x = self.position_embedding(src_x)
        src_x = Dropout(self.dropout)(src_x)
        self.encoder = Encode(num_layers=self.num_layers, num_head=self.num_head,
                              head_dim=self.head_dim, model_dim=self.model_dim,
                              inner_dim=self.inner_dim, dropout=self.dropout)
        encoder_output = self.encoder(src_x, masks=src_mask)
        tar_x = self.target_embedding(target_decode_in)
        if self.use_pos_embedding:
            tar_x = self.position_embedding(tar_x)
        self.decoder = Decode(num_layers=self.num_layers, num_head=self.num_head,
                              head_dim=self.head_dim, model_dim=self.model_dim,
                              inner_dim=self.inner_dim, dropout=self.dropout)
        decoder_output = self.decoder([tar_x, encoder_output],
                                      self_mask=tar_mask, encode_mask=encode_mask)
        self.softmax = TimeDistributed(Dense(self.tar_dict_size))
        output = self.softmax(decoder_output)
        # The loss is computed inside the graph and attached via add_loss, so compile() below
        # takes no separate loss function.
        loss = Lambda(lambda x: self._get_loss(*x))([output, target_decode_out])
        self.model = Model([source_input, target_input], loss)
        self.model.add_loss([loss])
        self.model.compile(optimizer, None)
        self.model.metrics_names.append("ppl")
        self.model.metrics_tensors.append(Lambda(K.exp)(loss))
        self.model.metrics_names.append("accuracy")
        self.model.metrics_tensors.append(
            Lambda(lambda x: self._get_acc(x[0], x[1]))([output, target_decode_out]))
        self.output_model = Model([source_input, target_input], output)

    @staticmethod
    def get_encode_mask(src_seq):
        return get_mask_seq2seq(src_seq, src_seq)

    @staticmethod
    def get_self_mask(tar_seq):
        self_mask1 = get_mask_seq2seq(tar_seq, tar_seq)
        self_mask2 = get_mask_self(tar_seq)
        return K.minimum(self_mask1, self_mask2)

    @staticmethod
    def _get_loss(y_pred, y_true):
        y_true = tf.cast(y_true, dtype="int32")
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        mask = tf.cast(tf.not_equal(y_true, 0), dtype="float32")
        loss = tf.reduce_sum(loss * mask, -1) / tf.reduce_sum(mask, -1)
        return tf.reduce_mean(loss)

    @staticmethod
    def _get_acc(y_pred, y_true):
        mask = tf.cast(tf.not_equal(y_true, 0), dtype="float32")
        corr = K.cast(K.equal(K.cast(y_true, dtype="int32"),
                              K.cast(K.argmax(y_pred, -1), dtype="int32")),
                      dtype="float32")
        acc = K.sum(corr * mask, -1) / K.sum(mask, -1)
        return K.mean(acc)

    def decode_fast(self, seq):
        decode_tokens = []
        target_seq = np.zeros(shape=(1, self.length_limit), dtype=np.int32)
        target_seq[0, 0] = 2  # start-of-sequence token id
        for i in range(self.length_limit - 1):
            output = self.output_model.predict_on_batch([seq, target_seq])
            max_prob_index = np.argmax(output[0, i, :])
            max_prob_token = self.tar_token_dict[max_prob_index]
            decode_tokens.append(max_prob_token)
            if max_prob_index == 3:  # end-of-sequence token id
                break
            target_seq[0, i + 1] = max_prob_index
        return " ".join(decode_tokens)

    def _build_encoder(self):
        source_input = Input(shape=(None,), dtype="int32")
        src_mask = Lambda(lambda x: get_mask_seq2seq(x, x))(source_input)
        src_x = self.source_embedding(source_input)
        if self.use_pos_embedding:
            src_x = self.position_embedding(src_x)
        encoder_output = self.encoder(src_x, masks=src_mask)
        self.encoder_model = Model([source_input], encoder_output)
        self.encoder_model.compile('adam', 'mse')

    def _build_decoder(self):
        source_input = Input(shape=(None,), dtype="int32")
        target_input = Input(shape=(None,), dtype="int32")
        encoder_output = Input(shape=(None, self.model_dim))
        tar_mask = Lambda(lambda x: self.get_self_mask(x))(target_input)
        encode_mask = Lambda(lambda x: get_mask_seq2seq(x[0], x[1]))(
            [target_input, source_input])
        tar_x = self.target_embedding(target_input)
        if self.use_pos_embedding:
            tar_x = self.position_embedding(tar_x)
        decoder_output = self.decoder([tar_x, encoder_output],
                                      self_mask=tar_mask, encode_mask=encode_mask)
        final_output = self.softmax(decoder_output)
        self.decoder_model = Model([source_input, target_input, encoder_output], final_output)
        self.decoder_model.compile('adam', 'mse')

    def _build_decode_model(self):
        self._build_encoder()
        self._build_decoder()
        self.decode_build = True

    def decode(self, seq):
        if not self.decode_build:
            self._build_decode_model()
        decode_tokens = []
        target_seq = np.zeros(shape=(1, self.length_limit), dtype=np.int32)
        target_seq[0, 0] = 2
        # Run the encoder once and reuse its output at every decoding step.
        encoder_output = self.encoder_model.predict_on_batch([seq])
        for i in range(self.length_limit - 1):
            output = self.decoder_model.predict_on_batch([seq, target_seq, encoder_output])
            max_prob_index = np.argmax(output[0, i, :])
            max_prob_token = self.tar_token_dict[max_prob_index]
            decode_tokens.append(max_prob_token)
            if max_prob_index == 3:
                break
            target_seq[0, i + 1] = max_prob_index
        return " ".join(decode_tokens)

    def beam_search(self, seq, topk=3):
        if not self.decode_build:
            self._build_decode_model()
        seq = np.repeat(seq, topk, axis=0)
        encoder_output = self.encoder_model.predict_on_batch([seq])
        final_results = []
        topk_prob = np.zeros((topk,), dtype=np.float32)
        decode_tokens = [[] for _ in range(topk)]
        target_seq = np.zeros((topk, self.length_limit), dtype=np.int32)
        target_seq[:, 0] = 2
        last_k = 1
        for i in range(self.length_limit - 1):
            if last_k == 0 or len(final_results) > topk * 3:
                break  # stop conditions
            target_output = self.decoder_model.predict_on_batch(
                [seq, target_seq, encoder_output])
            output = np.exp(target_output[:, i, :])
            output = output / np.sum(output, axis=-1, keepdims=True)
            output = np.log(output + 1e-8)  # work in log space to avoid tiny probabilities
            candidates = []
            for k, probs in zip(range(last_k), output):
                if target_seq[k, i] == 3:
                    continue
                word_p_sort = sorted(list(enumerate(probs)), key=lambda x: x[1], reverse=True)
                for ind, wp in word_p_sort[:topk]:
                    candidates.append((k, ind, topk_prob[k] + wp))
            candidates = sorted(candidates, key=lambda x: x[-1], reverse=True)
            candidates = candidates[:topk]
            target_seq_bk = target_seq.copy()
            for new_k, cand in enumerate(candidates):
                k, ind, seq_p = cand
                target_seq[new_k] = target_seq_bk[k]
                target_seq[new_k, i + 1] = ind
                topk_prob[new_k] = seq_p
                decode_tokens.append(decode_tokens[k] + [self.tar_token_dict[ind]])
                if ind == 3:
                    final_results.append((decode_tokens[k], seq_p))
            decode_tokens = decode_tokens[topk:]
            last_k = len(decode_tokens)
        # Length-normalize the scores before ranking the completed hypotheses.
        final_results = [(x, y / (len(x) + 1)) for x, y in final_results]
        final_results = sorted(final_results, key=lambda x: x[1], reverse=True)
        return final_results
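A rough end-to-end sketch of how the class above might be used. The toy vocabulary, the special token ids (0 for padding, 2 for start, 3 for end, as implied by the masking and decoding code), and the training arrays src_train / tar_train are assumptions for illustration, not part of the original code.

# Toy vocabulary; ids 0/2/3 match how the code above treats padding, start, and end tokens.
src_dict = {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3, 'hello': 4, 'world': 5}

transformer = Transformer(src_dict, length_limit=20, num_layers=2, model_dim=64,
                          num_head=4, inner_dim=128, share_embedding=True)
transformer.compile(optimizer='adam')

# src_train / tar_train are hypothetical int32 id matrices; target rows look like
# [2, ..., 3, 0, 0]. The loss tensor was attached via add_loss, so y is None.
transformer.model.fit([src_train, tar_train], None, epochs=1, batch_size=32)

print(transformer.decode(src_train[:1]))                # greedy decoding with a cached encoder pass
print(transformer.beam_search(src_train[:1], topk=3))   # top-3 beam-search hypotheses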
class IWAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood, k_samples):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())
        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        self.model.add_loss(-self.elbo())

    def set_samples(self, k):
        K.set_value(self.k_samples, k)

    def elbo(self):
        batch = K.shape(self.inpt)[0]
        # shape of flat_input is (batch, pixels)
        flat_input = K.batch_flatten(self.inpt)
        # shape of repeated_input is (batch, samples, pixels)
        repeated_input = K.repeat(flat_input, self.k_samples)
        # shape of flat_repeated_input is (batch * samples, pixels) to match the shape of
        # self.reconstruction
        flat_repeated_input = K.reshape(repeated_input, (batch * self.k_samples, -1))
        # LL term is E_q(z|x) [ log p(x|z) ] (Note that the weighted sum over k_samples
        # happens later)
        # shape of flat_ll is (batch * samples,).
        flat_ll = self.likelihood.log_prob(flat_repeated_input)
        # shape of self.ll is (batch, samples)
        self.ll = K.reshape(flat_ll, (batch, -1))
        # Final loss is a weighted sum across k samples. More precisely, the total gradient is a
        # weighted sum of sample gradients. K.stop_gradient() is used to make the weights act on
        # gradients and not provide gradients themselves (weights are not 'learned' per se).
        # Weights have shape (batch, samples).
        weights = K.stop_gradient(self._get_weights())
        # KL term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch,) if analytic or
        # (batch, samples) otherwise
        try:
            # Use analytic KL if it is available, which has shape (batch,)
            self.kl = self.latent.analytic_kl()
            return K.sum(weights * self.ll, axis=-1) - self.kl
        except TypeError:
            # If analytic KL is not available, fall back on sample KL.
            self.kl = self.latent.sample_kl()
            # ELBO is the weighted sum over samples of (LL - KL) and has shape (batch,)
            return K.sum(weights * (self.ll - self.kl), axis=-1)

    def _get_weights(self):
        # IWAE sample weight on sample i is p(x, latent_i) / q(latent_i | x). Weights are then
        # normalized to sum to 1. First computing log-weights is more numerically stable.
        log_p = self.latent.prior.log_prob(self.latent.samples) + self.ll
        log_q = self.latent.log_prob(self.latent.samples)
        log_weights = log_p - log_q
        # Pre-normalize in log space by subtracting logsumexp (which is dividing by the sum in
        # probability space). This keeps results stable so that the following call to exp() is
        # given values in a reasonable range. Results may not sum to exactly 1, though, due to
        # floating point precision.
        log_weights -= K.logsumexp(log_weights, axis=-1, keepdims=True)
        # Get out of log space and normalize a second time since logsumexp is not perfect.
        weights_unnormalized = K.exp(log_weights)
        return weights_unnormalized / K.sum(weights_unnormalized, axis=-1, keepdims=True)
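As a sanity check on _get_weights(), here is the same normalized-importance-weight computation sketched in plain NumPy with made-up numbers; subtracting the per-row maximum plays the same stabilizing role as the logsumexp pre-normalization above.

import numpy as np

# Made-up values of log p(x, z_i) - log q(z_i | x) for one input and k = 3 samples.
log_w = np.array([[-100.0, -101.0, -103.0]])
log_w -= np.max(log_w, axis=-1, keepdims=True)   # stable pre-normalization (like logsumexp)
w = np.exp(log_w)
w /= np.sum(w, axis=-1, keepdims=True)           # final normalization so weights sum to 1
print(w)  # approximately [[0.71, 0.26, 0.04]]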