Example #1
  def _create_network(self):
    logF, loss_grads = self._create_loss()
    self._create_train_op(loss_grads)

    # Create IWAE lower bound for evaluation
    self.logF = self._reshape(logF)
    self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
                               tf.log(tf.to_float(self.n_samples)))
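The helper U.logSumExp is not defined in any of these snippets. As a point of reference, here is a minimal sketch of what such a helper presumably computes, using the standard max-subtraction trick for numerical stability (TF1-style API, to match the tf.log and tf.to_float calls above); the name and axis default are assumptions:

import tensorflow as tf

def logSumExp(x, axis=1):
    # log(sum_i exp(x_i)) computed stably: factor out the per-row max m,
    # so log sum exp(x) = m + log sum exp(x - m) and exp cannot overflow
    m = tf.reduce_max(x, axis=axis, keepdims=True)
    return tf.squeeze(m, axis=[axis]) + tf.log(
        tf.reduce_sum(tf.exp(x - m), axis=axis))

Subtracting tf.log(tf.to_float(self.n_samples)) afterwards turns the per-example log-sum over samples into a log-mean, i.e. the K-sample IWAE bound log((1/K) * sum_k exp(logF_k)).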
Example #2
    def assertEqualMarginals(self, graph, all_sequences, sent_likelihood):
        """
		Check factor/variable marginals are approximately equal
		to marginals obtained from brute force inference
		"""

        # Check variable marginals
        threshold = 0.01
        eq = True

        denom = -float("inf")
        maxDiff = -float("inf")

        for s, sequence in enumerate(all_sequences):
            denom = utils.logSumExp(sent_likelihood[s], denom)

        # Iterate over all timesteps
        for t in range(graph.T):
            for tag in self.model.uniqueTags:
                tagBeliefs = graph.getVarByTimestepnTag(
                    t, tag.idx).belief.cpu().numpy()
                for labelIdx in range(tag.size()):
                    num = -float("inf")
                    for s, sequence in enumerate(all_sequences):
                        if sequence[t][tag.idx] == labelIdx:
                            num = utils.logSumExp(sent_likelihood[s], num)

                    # Compare the belief (stored in log space) against the
                    # brute-force marginal P(tag = label) = exp(num - denom)
                    tagProb = np.exp(num - denom)
                    maxDiff = max(
                        maxDiff,
                        np.max(
                            np.abs(np.exp(tagBeliefs[labelIdx]) - tagProb)),
                    )
                    if maxDiff > threshold:
                        eq = False

        if not eq:
            print("Marginals not equal. Max difference of %f" % maxDiff)
        else:
            print("Passed unit test!")

        sys.exit(0)
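In this snippet utils.logSumExp is the two-argument form, used as a running log-space accumulator seeded with -inf. A plausible NumPy sketch of that helper (the real implementation is not shown; the guard and names are assumptions):

import numpy as np

def logSumExp(a, b):
    # log(exp(a) + exp(b)) computed stably; the guard keeps the -inf
    # seed value from producing exp(-inf - (-inf)) = nan
    m = max(a, b)
    if m == -float("inf"):
        return m
    return m + np.log(np.exp(a - m) + np.exp(b - m))

For scalar inputs like these, NumPy's built-in np.logaddexp(a, b) performs the same operation, including correct handling of -inf.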
Example #3
  def _create_network(self):
    logF, loss_grads, variance_objective, variance_objective_grad = self._create_loss()
    eta_grads = (self.optimizer_class.compute_gradients(variance_objective,
                                                        var_list=tf.get_collection('CV'))
                 + [(variance_objective_grad, self.pre_temperature_variable)])
    self._create_train_op(loss_grads, eta_grads)

    # Create IWAE lower bound for evaluation
    self.logF = self._reshape(logF)
    self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
                               tf.log(tf.to_float(self.n_samples)))
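Here eta_grads mixes the gradients the optimizer computes for the control-variate collection 'CV' with one hand-computed pair for the temperature parameter. compute_gradients returns a list of (gradient, variable) tuples, so extending that list with a manual pair and handing the result to apply_gradients is valid TF1 usage. A toy sketch of the same pattern, with illustrative names (w, eta, opt are not from this repo):

import tensorflow as tf

w = tf.Variable(1.0)
eta = tf.Variable(0.5)
loss = tf.square(w)
manual_grad = 2.0 * eta  # hand-derived gradient tensor for eta

opt = tf.train.AdamOptimizer(1e-3)
grads_and_vars = opt.compute_gradients(loss, var_list=[w])
grads_and_vars += [(manual_grad, eta)]  # append a (grad, var) pair
train_op = opt.apply_gradients(grads_and_vars)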
Example #4
  def _create_network(self):
    logF = self._create_loss()
    self.optimizerLoss = tf.reduce_mean(self.optimizerLoss)

    # Setup optimizer
    grads_and_vars = self.optimizer_class.compute_gradients(self.optimizerLoss)
    self._create_train_op(grads_and_vars)

    # Create IWAE lower bound for evaluation
    self.logF = self._reshape(logF)
    self.iwae = tf.reduce_mean(U.logSumExp(self.logF, axis=1) -
                               tf.log(tf.to_float(self.n_samples)))
Example #5
    def belief_propogation_log(self,
                               gold_tags,
                               lang,
                               sentLen,
                               batch_lstm_feats,
                               test=False):

        # Forward messages, then backward messages for each tag => O(n^2)

        threshold = 0.05
        maxIters = 50
        batch_size = len(batch_lstm_feats)
        if self.model_type == "specific":
            langIdx = self.langs.index(lang)

        # Initialize factor graph, add vars and factors
        print("Creating Factor Graph...")
        graph = FactorGraph(sentLen, batch_size, self.gpu)

        # Add variables to graph
        for tag in self.uniqueTags:
            for t in range(sentLen):
                label = None
                graph.addVariable(tag, label, t)

        if not self.no_pairwise:
            # Add pairwise factors to graph
            kind = "pair"
            for tag1 in self.uniqueTags:
                for tag2 in self.uniqueTags:
                    if tag1 != tag2 and tag1.idx < tag2.idx:
                        for t in range(sentLen):
                            var1 = graph.getVarByTimestepnTag(t, tag1.idx)
                            var2 = graph.getVarByTimestepnTag(t, tag2.idx)
                            graph.addFactor(kind, var1, var2)

            # Retrieve pairwise weights
            pairwise_weights_np = []
            for i in range(len(self.pairs)):
                pairwise_weights_np.append(
                    self.pairwise_weights[i].cpu().data.numpy())

            if self.model_type == "specific":
                for i in range(len(self.pairs)):
                    pairwise_weights_np[i] = utils.logSumExp(
                        pairwise_weights_np[i], self.lang_pairwise_weights[i]
                        [langIdx].cpu().data.numpy())

        if not self.no_transitions:
            # Add transition factors to graph
            kind = "trans"
            for tag in self.uniqueTags:
                for t in range(sentLen - 1):
                    var1 = graph.getVarByTimestepnTag(t, tag.idx)
                    var2 = graph.getVarByTimestepnTag(t + 1, tag.idx)
                    graph.addFactor(kind, var1, var2)

            transition_weights_np = {}
            for tag in self.uniqueTags:
                transition_weights_np[tag.idx] = self.transition_weights[
                    tag.idx].cpu().data.numpy()

            if self.model_type == "specific":
                for tag in self.uniqueTags:
                    transition_weights_np[tag.idx] = utils.logSumExp(
                        transition_weights_np[tag.idx],
                        self.lang_transition_weights[
                            tag.idx][langIdx].cpu().data.numpy())

        kind = "lstm"
        for tag in self.uniqueTags:
            for t in range(sentLen):
                var = graph.getVarByTimestepnTag(t, tag.idx)
                graph.addFactor(kind, var, "LSTMVar")

        # Initialize messages
        messages = Messages(graph, batch_size)

        # Add LSTM unary factor message to each variable
        for tag in self.uniqueTags:
            for t in range(sentLen):
                lstm_vecs = []
                var = graph.getVarByTimestepnTag(t, tag.idx)
                lstm_factor = graph.getFactorByVars(var, "LSTMVar")
                cur_tag_lstm_weights = self.lstm_weights[tag.idx]

                for batchIdx in range(batch_size):
                    lstm_feats = batch_lstm_feats[batchIdx]
                    cur_lstm_feats = lstm_feats[t]
                    cur_tag_lstm_feats = cur_lstm_feats[
                        self.tag_offsets[tag.name]:self.tag_offsets[tag.name] +
                        tag.size()]
                    lstm_vec = torch.unsqueeze(
                        cur_tag_lstm_weights + cur_tag_lstm_feats, 0)
                    lstm_vec = utils.logNormalizeTensor(lstm_vec).squeeze(
                        dim=0)
                    lstm_vecs.append(lstm_vec.cpu().data.numpy())

                messages.updateMessage(lstm_factor, var, np.array(lstm_vecs))

        iter = 0
        while iter < maxIters:
            print("[BP iteration %d]" % iter, end=" ")
            maxVal = [-float("inf")] * batch_size
            for tag in self.uniqueTags:
                var_list = graph.getVarsByTag(tag.idx)

                # FORWARD

                for t in range(sentLen):

                    var = var_list[t]

                    # Get pairwise potentials

                    factor_list = graph.getFactorByVars(var)
                    factor_sum = np.zeros((batch_size, var.tag.size()))

                    # Maintaining factor sum improves efficiency
                    for factor_mult in factor_list:
                        factor_sum += messages.getMessage(factor_mult,
                                                          var).value

                    for factor in factor_list:
                        if factor.kind == "pair":
                            var2 = factor.getOtherVar(var)

                            # Variable-to-factor message: factor_sum minus this
                            # factor's own message = sum of all other messages
                            message = factor_sum - messages.getMessage(
                                factor, var).value
                            message = utils.logNormalize(message)
                            curVal = messages.getMessage(var, factor).value

                            # From (Sutton, 2012)
                            maxVal = np.maximum(
                                maxVal, np.amax(np.abs(curVal - message), 1))
                            messages.updateMessage(var, factor, message)

                            # factor2variable
                            if var2.tag.idx < var.tag.idx:
                                pairwise_idx = self.pairs.index(
                                    (var2.tag.idx, var.tag.idx))
                                transpose = False
                            else:
                                pairwise_idx = self.pairs.index(
                                    (var.tag.idx, var2.tag.idx))
                                transpose = True

                            cur_pairwise_weights = pairwise_weights_np[
                                pairwise_idx]

                            if transpose:

                                pairwise_pot = utils.logDot(
                                    cur_pairwise_weights,
                                    messages.getMessage(var2, factor).value,
                                    redAxis=1)
                            else:
                                pairwise_pot = utils.logDot(
                                    messages.getMessage(var2, factor).value,
                                    cur_pairwise_weights,
                                    redAxis=0)

                            pairwise_pot = utils.logNormalize(pairwise_pot)
                            curVal = messages.getMessage(factor, var).value
                            maxVal = np.maximum(
                                maxVal,
                                np.amax(np.abs(curVal - pairwise_pot), 1))
                            messages.updateMessage(factor, var, pairwise_pot)
                            factor_sum += pairwise_pot - curVal

                    if not self.no_transitions:

                        cur_tag_weights = transition_weights_np[tag.idx]

                        # Get transition potential
                        if t != sentLen - 1:

                            var2 = graph.getVarByTimestepnTag(t + 1, tag.idx)
                            trans_factor = graph.getFactorByVars(var, var2)

                            # Variable-to-factor message: again reuse
                            # factor_sum instead of re-summing over all
                            # other incoming factor messages
                            message = factor_sum - messages.getMessage(
                                trans_factor, var).value

                            message = utils.logNormalize(message)
                            curVal = messages.getMessage(var,
                                                         trans_factor).value
                            maxVal = np.maximum(
                                maxVal, np.amax(np.abs(curVal - message), 1))
                            messages.updateMessage(var, trans_factor, message)
                            # Factor2Variable Message

                            transition_pot = utils.logDot(messages.getMessage(
                                var, trans_factor).value,
                                                          cur_tag_weights,
                                                          redAxis=0)

                            transition_pot = utils.logNormalize(transition_pot)
                            curVal = messages.getMessage(trans_factor,
                                                         var2).value
                            maxVal = np.maximum(
                                maxVal,
                                np.amax(np.abs(curVal - transition_pot), 1))
                            messages.updateMessage(trans_factor, var2,
                                                   transition_pot)

                # BACKWARD
                if not self.no_transitions:

                    for t in range(sentLen - 1, 0, -1):

                        var = var_list[t]
                        factor_list = graph.getFactorByVars(var)

                        # Variable2Factor Message

                        var2 = graph.getVarByTimestepnTag(t - 1, tag.idx)
                        trans_factor = graph.getFactorByVars(var, var2)

                        message = np.zeros((batch_size, var.tag.size()))

                        for i, factor_mult in enumerate(factor_list):
                            if factor_mult != trans_factor:
                                message += messages.getMessage(
                                    factor_mult, var).value

                        message = utils.logNormalize(message)
                        curVal = messages.getMessage(var, trans_factor).value
                        maxVal = np.maximum(
                            maxVal, np.amax(np.abs(curVal - message), 1))
                        messages.updateMessage(var, trans_factor, message)
                        transition_pot = utils.logDot(cur_tag_weights,
                                                      messages.getMessage(
                                                          var,
                                                          trans_factor).value,
                                                      redAxis=1)

                        transition_pot = utils.logNormalize(transition_pot)
                        curVal = messages.getMessage(trans_factor, var2).value
                        maxVal = np.maximum(
                            maxVal, np.amax(np.abs(curVal - transition_pot),
                                            1))
                        messages.updateMessage(trans_factor, var2,
                                               transition_pot)

            iter += 1
            print("Max Res Value: %f" % max(maxVal))
            if max(maxVal) <= threshold:
                print("Converged in %d iterations" % (iter))
                break
            if iter == maxIters:
                print("Diverging :( Finished %d iterations." % maxIters)
                return None

        # Calculate belief values and marginals

        # Variable beliefs
        for tag in self.uniqueTags:
            for t in range(sentLen):

                var = graph.getVarByTimestepnTag(t, tag.idx)

                factor_list = graph.getFactorByVars(var)
                for factor in factor_list:
                    factorMsg = Variable(
                        torch.FloatTensor(
                            messages.getMessage(factor, var).value))
                    if self.gpu:
                        factorMsg = factorMsg.cuda()
                    var.belief = var.belief + factorMsg

                # Normalize
                var.belief = utils.logNormalizeTensor(var.belief)

        # Factor beliefs
        for factor in graph.iterFactors():
            var1, var2 = graph.getVarsByFactor(factor)
            if factor.kind == "trans":
                factor.belief = self.transition_weights[var1.tag.idx]
                if self.model_type == "specific":
                    factor.belief = utils.logSumExpTensors(
                        factor.belief,
                        self.lang_transition_weights[var1.tag.idx][langIdx])
            elif factor.kind == "pair":
                pairwise_idx = self.pairs.index((var1.tag.idx, var2.tag.idx))
                factor.belief = self.pairwise_weights[pairwise_idx]
                if self.model_type == "specific":
                    factor.belief = utils.logSumExpTensors(
                        factor.belief,
                        self.lang_pairwise_weights[pairwise_idx][langIdx])
            else:
                continue

            factor.belief = factor.belief.view(1, factor.belief.size(0),
                                               -1).expand(batch_size, -1, -1)

            msg1 = torch.FloatTensor(messages.getMessage(var1, factor).value)
            msg2 = torch.FloatTensor(messages.getMessage(var2, factor).value)
            if self.gpu:
                msg1 = msg1.cuda()
                msg2 = msg2.cuda()

            factor.belief = Variable(
                msg1.view(batch_size, -1, 1).expand(
                    -1, -1, var2.tag.size())) + factor.belief
            factor.belief = Variable(
                msg2.view(batch_size, 1, -1).expand(-1, var1.tag.size(),
                                                    -1)) + factor.belief
            factor.belief = utils.logNormalizeTensor(factor.belief)

        return graph, max(maxVal)
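The belief-propagation loop above leans on utils.logNormalize to keep every message a proper log-distribution. A plausible NumPy sketch of that helper, normalizing along the last axis (the actual implementation is not shown; the name and axis convention are assumptions):

import numpy as np

def logNormalize(x):
    # Shift log-scores so that exp(x) sums to 1 along the last axis:
    # x - logsumexp(x), with the max factored out for stability
    m = np.max(x, axis=-1, keepdims=True)
    lse = m + np.log(np.sum(np.exp(x - m), axis=-1, keepdims=True))
    return x - lse

Normalizing after every update is what makes the residual check max(|old - new|) meaningful as a convergence test.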
Example #6
z_params = tf.reshape(net, [-1, N, K])
tau = tf.constant(0.5)
q_z = RelaxedOneHotCategorical(tau, z_params)
z = tf.to_float(q_z.sample())
net = slim.flatten(z)
net = slim.stack(net, slim.fully_connected, [D])
logits_x = slim.fully_connected(net, 784, activation_fn=None)
p_x = Bernoulli(logits=logits_x)

p_z = OneHotCategorical(probs=tf.ones_like(z_params) / K)  # uniform prior

# loss
logP = tf.reduce_sum(p_x.log_prob(x_), 1) - tf.reduce_sum(
    q_z.log_prob(z), 1) + tf.reduce_sum(p_z.log_prob(z), 1)
logF = tf.transpose(tf.reshape(logP, [sample_n, -1]))
iwae1 = U.logSumExp(logF, axis=1) - tf.log(tf.to_float(sample_n))
iwae2 = -U.logSumExp(-logF, axis=1) + tf.log(tf.to_float(sample_n))

loss_vae = -tf.reduce_mean(logP)
loss_iwae1 = -tf.reduce_mean(iwae1)
loss_iwae2 = -tf.reduce_mean(iwae2)

# evaluation
loglikelihood = tf.reduce_mean(iwae1)

# Training optimizers
train_op_vae = tf.train.AdamOptimizer(learning_rate=3e-4).minimize(loss_vae)
train_op_iwae1 = tf.train.AdamOptimizer(
    learning_rate=3e-4).minimize(loss_iwae1)
train_op_iwae2 = tf.train.AdamOptimizer(
    learning_rate=3e-4).minimize(loss_iwae2)
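A hypothetical TF1 training loop for the graph above. x_, train_op_iwae1, and loglikelihood come from the snippet itself; the mnist input pipeline (yielding binarized 784-dimensional batches) and the batch size are assumptions for illustration:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(10000):
        # mnist is a stand-in for any feed of binarized images; any
        # tiling of x_ across the sample_n importance samples happens
        # in graph-construction code not shown in this excerpt
        batch_x, _ = mnist.train.next_batch(100)
        _, ll = sess.run([train_op_iwae1, loglikelihood],
                         feed_dict={x_: batch_x})
        if step % 1000 == 0:
            print("step %d  IWAE bound: %.3f" % (step, ll))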