Example #1
# Assumed imports for the Keras snippets on this page (Keras 2.x, TensorFlow backend):
from keras.models import Model
from keras import backend as K


class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())

        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        # To maximize ELBO, keras will minimize "loss" of -ELBO
        self.model.add_loss(-self.elbo())

    def elbo(self):
        flat_input = K.batch_flatten(self.inpt)

        # LL term is E_q(z|x) [ log p(x|z) ] and has shape (batch,)
        self.ll = self.likelihood.log_prob(flat_input)

        # KL term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch,)
        try:
            # Use analytic KL if it is available, or fall back on sample KL
            self.kl = self.latent.analytic_kl()
        except TypeError:
            self.kl = self.latent.sample_kl()

        # ELBO is simply (LL - KL) and has shape (batch,)
        return self.ll - self.kl
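
For reference, the quantity being maximized is the standard evidence lower bound, assembled exactly from the two terms computed above:

\mathrm{ELBO}(x) \;=\; \underbrace{\mathbb{E}_{q(z|x)}\big[\log p(x|z)\big]}_{\texttt{self.ll}} \;-\; \underbrace{\mathbb{E}_{q(z|x)}\!\left[\log \tfrac{q(z|x)}{p(z)}\right]}_{\texttt{self.kl}}

so having Keras minimize -ELBO maximizes a lower bound on log p(x).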
Example #2
class IWAE(object):
    def __init__(self,
                 inpt,
                 latents,
                 reconstruction,
                 likelihood,
                 k_samples,
                 kl_weight=1):
        self.__dict__.update(locals())

        # 'Model' is a trainable keras object
        self.model = Model(inpt, reconstruction)
        self.model.add_loss(self.elbo_loss())

    def set_samples(self, k):
        # TODO - use learning_phase() to set to 1 during testing?
        K.set_value(self.k_samples, k)

    def elbo_loss(self):
        # Repeat inputs once per latent sample to size (batch, samples, -1), where -1 stands for
        # 'all subsequent dimensions flattened'.
        input_shape = K.shape(self.inpt)
        repeated_input = K.repeat(K.batch_flatten(self.inpt), self.k_samples)

        # NLL loss term is E_q(z|x) [ -log p(x|z) ] and has shape (batch, samples)
        batch_sample_shape = K.stack([input_shape[0], self.k_samples, -1])
        batch_sample_reconstruction = K.reshape(self.reconstruction,
                                                batch_sample_shape)
        self.nll = self.likelihood.nll(repeated_input,
                                       batch_sample_reconstruction)

        # Each KL loss term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch, samples)
        kl_losses = K.sum([latent.sample_kl() for latent in self.latents],
                          axis=0)

        # Total loss is simply sum of KL and NLL terms and has shape (batch, samples)
        total_loss = self.kl_weight * kl_losses + self.nll

        # Final loss is weighted sum across k samples. More precisely, the total gradient is a
        # weighted sum of sample gradients. K.stop_gradient() is used to make the weights act on
        # gradients and not provide gradients themselves (weights are not 'learned' per se).
        # Weights have shape (batch, samples).
        weights = K.stop_gradient(self._get_weights())

        # Final loss per input is a weighted sum of sample losses
        return K.sum(total_loss * weights, axis=-1)

    def _get_weights(self):
        log_likelihood = -self.nll
        log_p = K.sum([q.prior.log_prob(q.samples) for q in self.latents],
                      axis=0)
        log_q = K.sum([q.log_prob(q.samples) for q in self.latents], axis=0)
        log_weights = log_likelihood + log_p - log_q
        log_weights -= K.logsumexp(log_weights, axis=-1, keepdims=True)
        weights_unnormalized = K.exp(log_weights)
        return weights_unnormalized / K.sum(
            weights_unnormalized, axis=-1, keepdims=True)
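
The _get_weights method computes the normalized importance weights of the IWAE objective. In log space, with samples z_i drawn from q(z|x),

\log w_i \;=\; \log p(x \mid z_i) + \log p(z_i) - \log q(z_i \mid x),
\qquad
\tilde{w}_i \;=\; \frac{\exp(\log w_i)}{\sum_j \exp(\log w_j)},

and subtracting the logsumexp before exponentiating is exactly a numerically stable softmax over the k samples; the second normalization only corrects floating-point rounding.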
Example #3
class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())

        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        # To maximize ELBO, keras will minimize "loss" of -ELBO
        self.model.add_loss(-self.elbo())

    def elbo(self):
        flat_input = K.batch_flatten(self.inpt)

        # LL term is E_q(z|x) [ log p(x|z) ] and has shape (batch,)
        self.ll = self.likelihood.log_prob(flat_input)

        # KL term is E_q(z|x) [ log q(z|x) / p(z) ]; use analytic KL if
        # available, or fall back on sample KL
        try:
            self.kl = self.latent.analytic_kl()
        except TypeError:
            self.kl = self.latent.sample_kl()

        # ELBO is simply (LL - KL) and has shape (batch,)
        return self.ll - self.kl
Example #4
class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood, k_samples):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())

        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        self.model.add_loss(-self.elbo())

    def set_samples(self, k):
        K.set_value(self.k_samples, k)

    def elbo(self):
        batch = K.shape(self.inpt)[0]
        # Repeat each input once per latent sample, flattened to
        # (batch * samples, pixels) to match self.reconstruction
        repeated_input = K.repeat(K.batch_flatten(self.inpt), self.k_samples)
        flat_repeated_input = K.reshape(repeated_input,
                                        (batch * self.k_samples, -1))

        # LL term is E_q(z|x) [ log p(x|z) ], reshaped to (batch, samples)
        self.ll = K.reshape(self.likelihood.log_prob(flat_repeated_input),
                            (batch, -1))

        try:
            # Analytic KL has shape (batch,)
            self.kl = self.latent.analytic_kl()
            return K.sum(self.ll, axis=-1) / K.cast(self.k_samples,
                                                    'float32') - self.kl
        except TypeError:
            # Sample KL has shape (batch, samples); ELBO is the mean over
            # samples of (LL - KL) and has shape (batch,)
            self.kl = self.latent.sample_kl()
            return K.sum(self.ll - self.kl, axis=-1) / K.cast(
                self.k_samples, 'float32')
Example #5
class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood, k_samples):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())

        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        self.model.add_loss(-self.elbo())

    def set_samples(self, k):
        K.set_value(self.k_samples, k)

    def elbo(self):
        batch = K.shape(self.inpt)[0]
        # shape of flat_input is (batch, pixels)
        flat_input = K.batch_flatten(self.inpt)
        # shape of repeated_input is (batch, samples, pixels)
        repeated_input = K.repeat(flat_input, self.k_samples)
        # shape of flat_repeated_input is (batch * samples, pixels) to match the shape of self.reconstruction
        flat_repeated_input = K.reshape(repeated_input,
                                        (batch * self.k_samples, -1))

        # LL term is E_q(z|x) [ log p(x|z) ] (Note that mean over k_samples happens later)
        # shape of flat_ll is (batch * samples,).
        flat_ll = self.likelihood.log_prob(flat_repeated_input)
        # shape of self.ll is (batch, samples)
        self.ll = K.reshape(flat_ll, (batch, -1))

        # KL term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch,) if analytic or
        # (batch, samples) otherwise
        try:
            # Use analytic KL if it is available, which has shape (batch,)
            self.kl = self.latent.analytic_kl()

            return K.sum(self.ll, axis=-1) / K.cast(self.k_samples,
                                                    'float32') - self.kl
        except TypeError:
            # If analytic KL is not available, fall back on sample KL.
            self.kl = self.latent.sample_kl()

            # ELBO is mean-over-samples of (LL - KL) and has shape (batch,)
            return K.sum(self.ll - self.kl, axis=-1) / K.cast(
                self.k_samples, 'float32')
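
The flatten/repeat/reshape bookkeeping in elbo() is easy to get wrong. Below is a minimal numpy sketch of the same shape manipulations, where numpy stands in for the Keras backend and a row-sum stands in for likelihood.log_prob:

import numpy as np

batch, samples, pixels = 4, 3, 5
flat_input = np.random.rand(batch, pixels)
# Like K.repeat: (batch, pixels) -> (batch, samples, pixels)
repeated_input = np.repeat(flat_input[:, None, :], samples, axis=1)
# Collapse to (batch * samples, pixels) to line up with the reconstruction
flat_repeated_input = repeated_input.reshape(batch * samples, -1)
# Stand-in for likelihood.log_prob: one scalar per flattened row
flat_ll = flat_repeated_input.sum(axis=-1)
ll = flat_ll.reshape(batch, -1)
assert ll.shape == (batch, samples)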
Example #6
class VAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())

        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        # To maximize ELBO, keras will minimize "loss" of -ELBO
        self.model.add_loss(-self.elbo())

    def elbo(self):
        flat_input = K.batch_flatten(self.inpt)

        # LL term is E_q(z|x) [ log p(x|z) ] and has shape (batch,)
        self.ll = self.likelihood.log_prob(flat_input)

        # KL term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch,);
        # use analytic KL if available, or fall back on sample KL
        try:
            self.kl = self.latent.analytic_kl()
        except TypeError:
            self.kl = self.latent.sample_kl()

        # ELBO is simply (LL - KL) and has shape (batch,)
        return self.ll - self.kl
Example #7
# Assumed imports for this snippet (Keras 2.x, TensorFlow backend). Encode, Decode,
# PositionEmbedding, get_mask_seq2seq, and get_mask_self are project-specific helpers.
import numpy as np
import tensorflow as tf
from keras.models import Model
from keras.layers import (Input, Lambda, Embedding, Dense, Dropout,
                          TimeDistributed)
from keras import backend as K


# Subclassing Model is assumed here: the super().__init__(inputs=..., outputs=...)
# branch below only makes sense for a keras Model, not a plain object.
class Transformer(Model):
    def __init__(self,
                 src_dict,
                 tar_dict=None,
                 length_limit=70,
                 num_layers=6,
                 model_dim=512,
                 num_head=8,
                 head_dim=None,
                 inner_dim=2048,
                 dropout=0.1,
                 use_pos_embedding=True,
                 share_embedding=False,
                 inputs=None,
                 outputs=None,
                 name="Transformer"):
        if inputs is not None and outputs is not None:
            super(Transformer, self).__init__(inputs=inputs,
                                              outputs=outputs,
                                              name=name)
            return

        self.src_dict = src_dict
        self.tar_dict = tar_dict if tar_dict is not None else self.src_dict
        self.src_token_dict = {v: k for k, v in self.src_dict.items()}
        self.tar_token_dict = {v: k for k, v in self.tar_dict.items()}
        self.src_dict_size = len(self.src_dict)
        self.tar_dict_size = len(
            self.tar_dict) if tar_dict is not None else self.src_dict_size

        self.length_limit = length_limit

        self.num_layers = num_layers
        self.num_head = num_head
        self.model_dim = model_dim
        self.head_dim = head_dim if head_dim is not None else int(model_dim /
                                                                  num_head)
        self.inner_dim = inner_dim

        self.dropout = dropout
        self.use_pos_embedding = use_pos_embedding
        self.share_embedding = share_embedding

        self.source_embedding = None
        self.target_embedding = None
        self.position_embedding = None
        self.encoder = None
        self.decoder = None
        self.softmax = None
        self.model = None
        self.output_model = None

        self.decode_build = False
        self.encoder_model = None
        self.decoder_model = None

    def compile(self, optimizer="adam"):
        source_input = Input(shape=(None, ), dtype="int32")
        target_input = Input(shape=(None, ), dtype="int32")

        # Teacher forcing: the decoder consumes tokens[:, :-1] and is trained
        # to predict tokens[:, 1:]
        target_decode_in = Lambda(lambda x: K.slice(
            x,
            start=[0, 0],
            size=[K.shape(target_input)[0],
                  K.shape(target_input)[1] - 1]))(target_input)
        target_decode_out = Lambda(lambda x: K.slice(
            x,
            start=[0, 1],
            size=[K.shape(target_input)[0],
                  K.shape(target_input)[1] - 1]))(target_input)

        src_mask = Lambda(lambda x: get_mask_seq2seq(x, x))(source_input)
        tar_mask = Lambda(lambda x: self.get_self_mask(x))(target_decode_in)
        encode_mask = Lambda(lambda x: get_mask_seq2seq(x[0], x[1]))(
            [target_decode_in, source_input])

        self.source_embedding = Embedding(input_dim=self.src_dict_size,
                                          output_dim=self.model_dim)
        if self.share_embedding:
            self.target_embedding = self.source_embedding
        else:
            self.target_embedding = Embedding(input_dim=self.tar_dict_size,
                                              output_dim=self.model_dim)

        if self.use_pos_embedding:
            self.position_embedding = PositionEmbedding(mode="sum")

        src_x = self.source_embedding(source_input)
        if self.use_pos_embedding:
            src_x = self.position_embedding(src_x)

        src_x = Dropout(self.dropout)(src_x)

        self.encoder = Encode(num_layers=self.num_layers,
                              num_head=self.num_head,
                              head_dim=self.head_dim,
                              model_dim=self.model_dim,
                              inner_dim=self.inner_dim,
                              dropout=self.dropout)
        encoder_output = self.encoder(src_x, masks=src_mask)

        tar_x = self.target_embedding(target_decode_in)
        if self.use_pos_embedding:
            tar_x = self.position_embedding(tar_x)

        self.decoder = Decode(num_layers=self.num_layers,
                              num_head=self.num_head,
                              head_dim=self.head_dim,
                              model_dim=self.model_dim,
                              inner_dim=self.inner_dim,
                              dropout=self.dropout)
        decoder_output = self.decoder([tar_x, encoder_output],
                                      self_mask=tar_mask,
                                      encode_mask=encode_mask)

        # Despite the name, this layer produces logits; softmax is applied
        # inside the loss via sparse_softmax_cross_entropy_with_logits
        self.softmax = TimeDistributed(Dense(self.tar_dict_size))

        output = self.softmax(decoder_output)

        loss = Lambda(lambda x: self._get_loss(*x))(
            [output, target_decode_out])

        self.model = Model([source_input, target_input], loss)
        self.model.add_loss([loss])
        self.model.compile(optimizer, None)

        self.model.metrics_names.append("ppl")
        self.model.metrics_tensors.append(Lambda(K.exp)(loss))
        self.model.metrics_names.append("accuracy")
        self.model.metrics_tensors.append(
            Lambda(lambda x: self._get_acc(x[0], x[1]))(
                [output, target_decode_out]))

        self.output_model = Model([source_input, target_input], output)

    @staticmethod
    def get_encode_mask(src_seq):
        return get_mask_seq2seq(src_seq, src_seq)

    @staticmethod
    def get_self_mask(tar_seq):
        self_mask1 = get_mask_seq2seq(tar_seq, tar_seq)
        self_mask2 = get_mask_self(tar_seq)
        return K.minimum(self_mask1, self_mask2)

    @staticmethod
    def _get_loss(y_pred, y_true):
        y_true = tf.cast(y_true, dtype="int32")
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true,
                                                              logits=y_pred)
        mask = tf.cast(tf.not_equal(y_true, 0), dtype="float32")
        loss = tf.reduce_sum(loss * mask, -1) / tf.reduce_sum(mask, -1)
        return tf.reduce_mean(loss)

    @staticmethod
    def _get_acc(y_pred, y_true):
        mask = tf.cast(tf.not_equal(y_true, 0), dtype="float32")
        corr = K.cast(K.equal(K.cast(y_true, dtype="int32"),
                              K.cast(K.argmax(y_pred, -1), dtype="int32")),
                      dtype="float32")
        acc = K.sum(corr * mask, -1) / K.sum(mask, -1)
        return K.mean(acc)

    def decode_fast(self, seq):
        decode_tokens = []
        target_seq = np.zeros(shape=(1, self.length_limit), dtype=np.int32)
        target_seq[0, 0] = 2  # id 2 plays the role of the start-of-sequence token
        for i in range(self.length_limit - 1):
            output = self.output_model.predict_on_batch([seq, target_seq])
            max_prob_index = np.argmax(output[0, i, :])
            max_prob_token = self.tar_token_dict[max_prob_index]
            decode_tokens.append(max_prob_token)
            if max_prob_index == 3:  # id 3 plays the role of end-of-sequence
                break
            target_seq[0, i + 1] = max_prob_index
        return " ".join(decode_tokens)

    def _build_encoder(self):
        source_input = Input(shape=(None, ), dtype="int32")

        src_mask = Lambda(lambda x: get_mask_seq2seq(x, x))(source_input)

        src_x = self.source_embedding(source_input)
        if self.use_pos_embedding:
            src_x = self.position_embedding(src_x)

        encoder_output = self.encoder(src_x, masks=src_mask)
        self.encoder_model = Model([source_input], encoder_output)
        self.encoder_model.compile('adam', 'mse')

    def _build_decoder(self):
        source_input = Input(shape=(None, ), dtype="int32")
        target_input = Input(shape=(None, ), dtype="int32")
        encoder_output = Input(shape=(None, self.model_dim))

        tar_mask = Lambda(lambda x: self.get_self_mask(x))(target_input)
        encode_mask = Lambda(lambda x: get_mask_seq2seq(x[0], x[1]))(
            [target_input, source_input])

        tar_x = self.target_embedding(target_input)
        if self.use_pos_embedding:
            tar_x = self.position_embedding(tar_x)

        decoder_output = self.decoder([tar_x, encoder_output],
                                      self_mask=tar_mask,
                                      encode_mask=encode_mask)
        final_output = self.softmax(decoder_output)
        self.decoder_model = Model(
            [source_input, target_input, encoder_output], final_output)
        self.decoder_model.compile('adam', 'mse')

    def _build_decode_model(self):
        self._build_encoder()
        self._build_decoder()
        self.decode_build = True

    def decode(self, seq):
        if not self.decode_build:
            self._build_decode_model()

        decode_tokens = []
        target_seq = np.zeros(shape=(1, self.length_limit), dtype=np.int32)
        target_seq[0, 0] = 2

        encoder_output = self.encoder_model.predict_on_batch([seq])
        for i in range(self.length_limit - 1):
            output = self.decoder_model.predict_on_batch(
                [seq, target_seq, encoder_output])
            max_prob_index = np.argmax(output[0, i, :])
            max_prob_token = self.tar_token_dict[max_prob_index]
            decode_tokens.append(max_prob_token)
            if max_prob_index == 3:
                break
            target_seq[0, i + 1] = max_prob_index
        return " ".join(decode_tokens)

    def beam_search(self, seq, topk=3):
        if not self.decode_build:
            self._build_decode_model()

        seq = np.repeat(seq, topk, axis=0)
        encoder_output = self.encoder_model.predict_on_batch([seq])

        final_results = []
        topk_prob = np.zeros((topk, ), dtype=np.float32)
        decode_tokens = [[] for _ in range(topk)]

        target_seq = np.zeros((topk, self.length_limit), dtype=np.int32)
        target_seq[:, 0] = 2

        last_k = 1

        for i in range(self.length_limit - 1):
            if last_k == 0 or len(final_results) > topk * 3:
                break  # stop conditions

            target_output = self.decoder_model.predict_on_batch(
                [seq, target_seq, encoder_output])
            output = np.exp(target_output[:, i, :])
            output = output / np.sum(output, axis=-1, keepdims=True)
            # Work in log space; the epsilon guards against log(0) on tiny
            # probabilities
            output = np.log(output + 1e-8)

            candidates = []

            for k, probs in zip(range(last_k), output):
                if target_seq[k, i] == 3:
                    continue

                word_p_sort = sorted(list(enumerate(probs)),
                                     key=lambda x: x[1],
                                     reverse=True)
                for ind, wp in word_p_sort[:topk]:
                    candidates.append((k, ind, topk_prob[k] + wp))

            candidates = sorted(candidates, key=lambda x: x[-1], reverse=True)
            candidates = candidates[:topk]

            target_seq_bk = target_seq.copy()

            for new_k, cand in enumerate(candidates):
                k, ind, seq_p = cand
                target_seq[new_k] = target_seq_bk[k]
                target_seq[new_k, i + 1] = ind
                topk_prob[new_k] = seq_p
                decode_tokens.append(decode_tokens[k] +
                                     [self.tar_token_dict[ind]])
                if ind == 3:
                    final_results.append((decode_tokens[k], seq_p))

            decode_tokens = decode_tokens[topk:]
            last_k = len(decode_tokens)

        final_results = [(x, y / (len(x) + 1)) for x, y in final_results]
        final_results = sorted(final_results, key=lambda x: x[1], reverse=True)
        return final_results
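
_get_loss averages the per-token cross-entropy only over non-padding positions (padding id 0), then averages over the batch; _get_acc applies the same mask. A standalone numpy sketch of that masking logic, with hypothetical names:

import numpy as np

def masked_token_loss(logits, labels):
    # logits: (batch, time, vocab); labels: (batch, time), 0 marks padding
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    token_loss = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    mask = (labels != 0).astype(np.float32)
    # Mean over real tokens per sequence, then mean over the batch
    return ((token_loss * mask).sum(-1) / mask.sum(-1)).mean()

logits = np.random.randn(2, 4, 10)
labels = np.array([[5, 7, 3, 0], [2, 3, 0, 0]])
print(masked_token_loss(logits, labels))

And a hypothetical end-to-end usage sketch: the vocabulary below is made up, ids 2 and 3 match the start/end-of-sequence conventions hard-coded in the decoding loops, and training is elided.

src_dict = {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3, "hello": 4, "world": 5}
transformer = Transformer(src_dict, length_limit=10, num_layers=2,
                          model_dim=64, num_head=4, inner_dim=128)
transformer.compile("adam")
# ... fit transformer.model on [source_ids, target_ids] batches here ...
seq = np.array([[2, 4, 5, 3]], dtype=np.int32)  # "<s> hello world </s>"
print(transformer.decode_fast(seq))
print(transformer.beam_search(seq, topk=3))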
Example #8
class IWAE(object):
    def __init__(self, inpt, latent, reconstruction, likelihood, k_samples):
        # Create self.inpt, self.latent, self.reconstruction, and self.likelihood
        self.__dict__.update(locals())

        # 'Model' is a trainable keras object.
        self.model = Model(inpt, reconstruction)
        self.model.add_loss(-self.elbo())

    def set_samples(self, k):
        K.set_value(self.k_samples, k)

    def elbo(self):
        batch = K.shape(self.inpt)[0]
        # shape of flat_input is (batch, pixels)
        flat_input = K.batch_flatten(self.inpt)
        # shape of repeated_input is (batch, samples, pixels)
        repeated_input = K.repeat(flat_input, self.k_samples)
        # shape of flat_repeated_input is (batch * samples, pixels) to match the shape of self.reconstruction
        flat_repeated_input = K.reshape(repeated_input,
                                        (batch * self.k_samples, -1))

        # LL term is E_q(z|x) [ log p(x|z) ] (Note that mean over k_samples happens later)
        # shape of flat_ll is (batch * samples,).
        flat_ll = self.likelihood.log_prob(flat_repeated_input)
        # shape of self.ll is (batch, samples)
        self.ll = K.reshape(flat_ll, (batch, -1))

        # Final loss is weighted sum across k samples. More precisely, the total gradient is a
        # weighted sum of sample gradients. K.stop_gradient() is used to make the weights act on
        # gradients and not provide gradients themselves (weights are not 'learned' per se).
        # Weights have shape (batch, samples).
        weights = K.stop_gradient(self._get_weights())

        # KL term is E_q(z|x) [ log q(z|x) / p(z) ] and has shape (batch,) if analytic or
        # (batch, samples) otherwise
        try:
            # Use analytic KL if it is available, which has shape (batch,)
            self.kl = self.latent.analytic_kl()

            return K.sum(weights * self.ll, axis=-1) - self.kl
        except TypeError:
            # If analytic KL is not available, fall back on sample KL.
            self.kl = self.latent.sample_kl()

            # ELBO is mean-over-samples of (LL - KL) and has shape (batch,)
            return K.sum(weights * (self.ll - self.kl), axis=-1)

    def _get_weights(self):
        # IWAE sample weight on sample i is p(x,latent_i)/q(latent_i|x). Weights are then
        # normalized to sum to 1. First computing log-weights is more numerically stable.
        log_p = self.latent.prior.log_prob(self.latent.samples) + self.ll
        log_q = self.latent.log_prob(self.latent.samples)
        log_weights = log_p - log_q
        # Pre-normalize results in log space by subtracting logsumexp (which is dividing by sum in
        # probability space). This keeps results stable so that the following call to exp() is
        # given values in a reasonable range. Results may not sum to exactly 1, though, due to
        # floating point precision.
        log_weights -= K.logsumexp(log_weights, axis=-1, keepdims=True)
        # Get out of log space and normalize a second time since logsumexp is not perfect.
        weights_unnormalized = K.exp(log_weights)
        return weights_unnormalized / K.sum(
            weights_unnormalized, axis=-1, keepdims=True)
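
The log-space normalization in _get_weights matters in practice: joint log-probabilities of high-dimensional data are routinely so negative that a naive exp() underflows to zero. A small numpy illustration, with scipy's logsumexp standing in for K.logsumexp:

import numpy as np
from scipy.special import logsumexp

# Log-weights far below exp()'s range; np.exp(log_w) alone gives all zeros
log_w = np.array([[-1000.0, -1001.0, -1002.0]])
log_w = log_w - logsumexp(log_w, axis=-1, keepdims=True)
weights = np.exp(log_w)
weights /= weights.sum(axis=-1, keepdims=True)  # second pass mops up rounding
print(weights)  # approximately [[0.665, 0.245, 0.090]]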