Example 1
    def get_conditional_probability(self, t1_min_embed, t1_max_embed,
                                    t2_min_embed, t2_max_embed):
        _, _, meet_min, meet_max, disjoint = unit_cube.calc_join_and_meet(
            t1_min_embed, t1_max_embed, t2_min_embed, t2_max_embed)
        nested = unit_cube.calc_nested(t1_min_embed, t1_max_embed,
                                       t2_min_embed, t2_max_embed,
                                       self.embed_dim)
        """get conditional probabilities"""
        overlap_volume = tf.reduce_prod(
            tf.nn.softplus((meet_max - meet_min) / self.temperature) *
            self.temperature,
            axis=-1)
        rhs_volume = tf.reduce_prod(
            tf.nn.softplus((t1_max_embed - t1_min_embed) / self.temperature) *
            self.temperature,
            axis=-1)
        conditional_logits = tf.log(overlap_volume + 1e-10) - \
            tf.log(rhs_volume + 1e-10)
        return (conditional_logits, meet_min, meet_max, disjoint, nested,
                overlap_volume, rhs_volume)
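For reference, here is the same soft-volume computation in plain NumPy (a minimal sketch; soft_volume, conditional_logit, and their arguments are illustrative names). The temperature smooths the per-dimension side lengths, and the 1e-10 terms guard the logs exactly as above.

import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)

def soft_volume(box_min, box_max, temperature):
    # Product over dimensions of the softplus-smoothed side lengths.
    return np.prod(softplus((box_max - box_min) / temperature) * temperature,
                   axis=-1)

def conditional_logit(meet_min, meet_max, t1_min, t1_max, temperature,
                      eps=1e-10):
    # log P(t2 | t1) = log vol(t1 meet t2) - log vol(t1)
    return (np.log(soft_volume(meet_min, meet_max, temperature) + eps)
            - np.log(soft_volume(t1_min, t1_max, temperature) + eps))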
Example 2
    def __init__(self, data, placeholder, FLAGS):
        self.optimizer = FLAGS.optimizer
        self.opti_epsilon = FLAGS.epsilon
        self.lr = FLAGS.learning_rate
        self.vocab_size = data.vocab_size
        self.measure = FLAGS.measure
        self.embed_dim = FLAGS.embed_dim
        self.batch_size = FLAGS.batch_size
        self.rel_size = FLAGS.rel_size
        self.tuple_model = FLAGS.tuple_model
        self.init_embedding = FLAGS.init_embedding
        self.rang = tf.range(0, FLAGS.batch_size, 1)
        # LSTM Params
        self.term = FLAGS.term
        self.hidden_dim = FLAGS.hidden_dim
        self.peephole = FLAGS.peephole
        self.freeze_grad = FLAGS.freeze_grad

        self.t1x = placeholder['t1_idx_placeholder']
        self.t1mask = placeholder['t1_msk_placeholder']
        self.t1length = placeholder['t1_length_placeholder']
        self.t2x = placeholder['t2_idx_placeholder']
        self.t2mask = placeholder['t2_msk_placeholder']
        self.t2length = placeholder['t2_length_placeholder']
        self.rel = placeholder['rel_placeholder']
        self.relmsk = placeholder['rel_msk_placeholder']
        self.label = placeholder['label_placeholder']
        """Initiate box embeddings"""
        self.min_embed, self.delta_embed, self.rel_embed = self.init_word_embedding(
            data)
        """get unit box representation for both term, no matter they are phrases or words"""
        if self.term:
            # if the terms are phrases, need to use either word average or lstm to compose the word embedding
            # Then transform them into unit cube.
            raw_t1_min_embed, raw_t1_delta_embed, raw_t2_min_embed, raw_t2_delta_embed = self.get_term_word_embedding(
                self.t1x, self.t1mask, self.t1length, self.t2x, self.t2mask,
                self.t2length, False)
            self.t1_min_embed, self.t1_max_embed, self.t2_min_embed, self.t2_max_embed = self.transform_cube(
                raw_t1_min_embed, raw_t1_delta_embed, raw_t2_min_embed,
                raw_t2_delta_embed)

        else:
            self.t1_min_embed, self.t1_max_embed, self.t2_min_embed, self.t2_max_embed = self.get_word_embedding(
                self.t1x, self.t2x)
        """get negative example unit box representation, if it's randomly generated during training."""
        if FLAGS.neg == 'uniform':
            # broken now, need to have different graph during train and evaluation
            nt1_min_embed, nt1_max_embed, nt2_min_embed, nt2_max_embed = self.generate_neg(
                self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
                self.t2_max_embed)
            self.t1_min_embed = tf.concat([self.t1_min_embed, nt1_min_embed],
                                          axis=0)
            self.t1_max_embed = tf.concat([self.t1_max_embed, nt1_max_embed],
                                          axis=0)
            self.t2_min_embed = tf.concat([self.t2_min_embed, nt2_min_embed],
                                          axis=0)
            self.t2_max_embed = tf.concat([self.t2_max_embed, nt2_max_embed],
                                          axis=0)
            self.label = tf.concat([self.label, tf.zeros_like(self.label)], 0)
        """calculate box stats, join, meet and overlap condition"""
        self.join_min, self.join_max, self.meet_min, self.meet_max, self.disjoint = unit_cube.calc_join_and_meet(
            self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
            self.t2_max_embed)
        """calculate -log(p(term2 | term1)) if overlap, surrogate function if not overlap"""
        # two surrogate function choice. lambda_batch_log_upper_bound or lambda_batch_disjoint_box
        if FLAGS.surrogate_bound:
            surrogate_func = unit_cube.lambda_batch_log_upper_bound
        else:
            surrogate_func = unit_cube.lambda_batch_disjoint_box

        train_pos_prob = tf_utils.slicing_where(
            condition=self.disjoint,
            full_input=tf.tuple([
                self.join_min, self.join_max, self.meet_min, self.meet_max,
                self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
                self.t2_max_embed
            ]),
            true_branch=lambda x: surrogate_func(*x),
            false_branch=lambda x: unit_cube.lambda_batch_log_prob(*x))
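        # tf_utils.slicing_where (a helper not shown on this page) routes each
        # batch row through true_branch where `condition` holds and through
        # false_branch otherwise, then reassembles the outputs in the original
        # row order; see the standalone sketch after this example.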
        # train_pos_prob = tf.Print(train_pos_prob, [tf.reduce_sum(tf.cast(tf.logical_and(
        #     self.disjoint, tf.logical_not(tf.cast(self.label, tf.bool))), tf.float32))],
        #                           'neg disjoint value', summarize=3)
        # train_pos_prob = tf.Print(train_pos_prob, [tf.reduce_sum(tf.cast(tf.logical_and(
        #     self.disjoint, tf.cast(self.label, tf.bool)), tf.float32))],
        #                           'pos disjoint value', summarize=3)
        """calculate -log(1-p(term2 | term1)) if overlap, 0 if not overlap"""
        train_neg_prob = tf_utils.slicing_where(
            condition=self.disjoint,
            full_input=([
                self.join_min, self.join_max, self.meet_min, self.meet_max,
                self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
                self.t2_max_embed
            ]),
            true_branch=lambda x: unit_cube.lambda_zero(*x),
            false_branch=lambda x: unit_cube.lambda_batch_log_1minus_prob(*x))
        """calculate negative log prob when evaluating pairs. The lower, the better"""
        # when return hierarchical error, we return the negative log probability, the lower, the probability higher
        # if two things are disjoint, we return -tf.log(1e-8).
        self.eval_prob = tf_utils.slicing_where(
            condition=self.disjoint,
            full_input=[
                self.join_min, self.join_max, self.meet_min, self.meet_max,
                self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
                self.t2_max_embed
            ],
            true_branch=lambda x: unit_cube.lambda_hierarchical_error_upper(*x),
            false_branch=lambda x: unit_cube.lambda_batch_log_prob(*x))
        """model marg prob loss"""
        if FLAGS.w2 > 0.0:
            self.marg_prob = tf.constant(data.margina_prob)
            kl_difference = unit_cube.calc_marginal_prob(
                self.marg_prob, self.min_embed, self.delta_embed)
            kl_difference = tf.reshape(kl_difference, [-1]) / self.vocab_size
            self.marg_loss = FLAGS.w2 * (tf.reduce_sum(kl_difference))
        else:
            self.marg_loss = 0.0
        """model cond prob loss"""
        self.pos = FLAGS.w1 * tf.multiply(train_pos_prob, self.label)
        self.neg = FLAGS.w1 * tf.multiply(train_neg_prob, (1 - self.label))
        self.pos_disjoint = tf.logical_and(tf.cast(self.label, tf.bool),
                                           self.disjoint)
        self.pos_overlap = tf.logical_and(tf.cast(self.label, tf.bool),
                                          tf.logical_not(self.disjoint))
        self.neg_disjoint = tf.logical_and(
            tf.logical_not(tf.cast(self.label, tf.bool)), self.disjoint)
        self.neg_overlap = tf.logical_and(
            tf.logical_not(tf.cast(self.label, tf.bool)),
            tf.logical_not(self.disjoint))
        self.pos_disjoint.set_shape([None])
        self.neg_disjoint.set_shape([None])
        self.pos_overlap.set_shape([None])
        self.neg_overlap.set_shape([None])
        # self.pos = tf.Print(self.pos, [tf.reduce_mean(tf.boolean_mask(self.pos, self.pos_disjoint))], 'pos disjoint loss')
        # self.pos = tf.Print(self.pos, [tf.reduce_mean(tf.boolean_mask(self.pos, self.pos_overlap))], 'pos overlap loss')
        # self.neg = tf.Print(self.neg, [tf.reduce_mean(tf.boolean_mask(self.neg, self.neg_disjoint))], 'neg disjoint loss')
        # self.neg = tf.Print(self.neg, [tf.reduce_mean(tf.boolean_mask(self.neg, self.neg_overlap))], 'neg overlap loss')
        # self.pos = tf.Print(self.pos, [tf.reduce_sum(self.pos)], 'pos loss')
        # self.neg = tf.Print(self.neg, [tf.reduce_sum(self.neg)], 'neg loss')
        self.cond_loss = tf.reduce_sum(self.pos) / (self.batch_size / 2) + \
                         tf.reduce_sum(self.neg) / (self.batch_size / 2)
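        # Each sum is normalized by batch_size / 2, i.e. batches are assumed to
        # contain positive and negative pairs in roughly equal proportion.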
        """model regurlization: make box to be poe-ish"""
        self.regularization = FLAGS.r1 * tf.reduce_sum(
            tf.abs(1 - self.min_embed - self.delta_embed)) / self.vocab_size
        """model final loss"""
        # self.cond_loss = tf.Print(self.cond_loss, [self.pos, self.neg])
        self.loss = self.cond_loss + self.marg_loss + self.regularization
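tf_utils.slicing_where is not listed on this page; below is a minimal sketch of the same route-by-condition pattern for a single input tensor, assuming TF 1.x (the function name and signature here are illustrative):

import tensorflow as tf

def slicing_where_sketch(condition, full_input, true_branch, false_branch):
    # Split the batch rows by `condition`, run each branch on its own
    # partition, then stitch the results back into the original row order.
    idx = tf.range(tf.shape(condition)[0])
    mask = tf.cast(condition, tf.int32)
    false_idx, true_idx = tf.dynamic_partition(idx, mask, 2)
    false_rows, true_rows = tf.dynamic_partition(full_input, mask, 2)
    return tf.dynamic_stitch([false_idx, true_idx],
                             [false_branch(false_rows), true_branch(true_rows)])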
Example 3
    def __init__(self, data, placeholder, FLAGS):
        self.optimizer = FLAGS.optimizer
        self.opti_epsilon = FLAGS.epsilon
        self.lr = FLAGS.learning_rate
        self.vocab_size = data.vocab_size
        self.measure = FLAGS.measure
        self.embed_dim = FLAGS.embed_dim
        self.batch_size = FLAGS.batch_size
        self.rel_size = FLAGS.rel_size
        self.tuple_model = FLAGS.tuple_model
        self.init_embedding = FLAGS.init_embedding
        self.rang = tf.range(0, FLAGS.batch_size, 1)
        # LSTM Params
        self.term = FLAGS.term
        self.hidden_dim = FLAGS.hidden_dim
        self.peephole = FLAGS.peephole
        self.freeze_grad = FLAGS.freeze_grad
        self.temperature = tf.Variable(FLAGS.temperature, trainable=False)
        self.decay_rate = FLAGS.decay_rate

        self.t1x = placeholder['t1_idx_placeholder']
        self.t1mask = placeholder['t1_msk_placeholder']
        self.t1length = placeholder['t1_length_placeholder']
        self.t2x = placeholder['t2_idx_placeholder']
        self.t2mask = placeholder['t2_msk_placeholder']
        self.t2length = placeholder['t2_length_placeholder']
        self.rel = placeholder['rel_placeholder']
        self.relmsk = placeholder['rel_msk_placeholder']
        self.label = placeholder['label_placeholder']
        """Initiate box embeddings"""
        self.min_embed, self.delta_embed = self.init_word_embedding(data)
        self.projector = unit_cube.MinMaxHyperCubeProjectorDeltaParam(
            self.min_embed, self.delta_embed, 0.0, 1e-10)
        self.project_op = self.projector.project_op
        """get unit box representation for both term, no matter they are phrases or words"""
        if self.term:
            # if the terms are phrases, need to use either word average or lstm to compose the word embedding
            # Then transform them into unit cube.
            raw_t1_min_embed, raw_t1_delta_embed, raw_t2_min_embed, raw_t2_delta_embed = self.get_term_word_embedding(
                self.t1x, self.t1mask, self.t1length, self.t2x, self.t2mask,
                self.t2length, False)
            self.t1_min_embed, self.t1_max_embed, self.t2_min_embed, self.t2_max_embed = self.transform_cube(
                raw_t1_min_embed, raw_t1_delta_embed, raw_t2_min_embed,
                raw_t2_delta_embed)

        else:
            self.t1_min_embed, self.t1_max_embed, self.t2_min_embed, self.t2_max_embed = self.get_word_embedding(
                self.t1x, self.t2x)
        """get negative example unit box representation, if it's randomly generated during training."""
        if FLAGS.neg == 'uniform':
            neg_num = 5
            self.nt1x = tf.random_uniform([self.batch_size * neg_num, 1],
                                          0,
                                          self.vocab_size,
                                          dtype=tf.int32)
            self.nt2x = tf.random_uniform([self.batch_size * neg_num, 1],
                                          0,
                                          self.vocab_size,
                                          dtype=tf.int32)
            self.nt1_min_embed, self.nt1_max_embed, self.nt2_min_embed, self.nt2_max_embed = self.get_word_embedding(
                self.nt1x, self.nt2x)
            # Pair each original t1 with a random t2 and each random t1 with an
            # original t2, yielding 2 * neg_num * batch_size corrupted pairs.
            self.nt1_min_embed = tf.concat(
                [tf.tile(self.t1_min_embed, [neg_num, 1]), self.nt1_min_embed],
                axis=0)
            self.nt1_max_embed = tf.concat(
                [tf.tile(self.t1_max_embed, [neg_num, 1]), self.nt1_max_embed],
                axis=0)
            self.nt2_min_embed = tf.concat(
                [self.nt2_min_embed,
                 tf.tile(self.t2_min_embed, [neg_num, 1])],
                axis=0)
            self.nt2_max_embed = tf.concat(
                [self.nt2_max_embed,
                 tf.tile(self.t2_max_embed, [neg_num, 1])],
                axis=0)
            self.label = tf.concat(
                [self.label,
                 tf.zeros([self.batch_size * neg_num * 2])], 0)
            self.t1_uniform_min_embed = tf.concat(
                [self.t1_min_embed, self.nt1_min_embed], axis=0)
            self.t1_uniform_max_embed = tf.concat(
                [self.t1_max_embed, self.nt1_max_embed], axis=0)
            self.t2_uniform_min_embed = tf.concat(
                [self.t2_min_embed, self.nt2_min_embed], axis=0)
            self.t2_uniform_max_embed = tf.concat(
                [self.t2_max_embed, self.nt2_max_embed], axis=0)
            """calculate box stats, join, meet and overlap condition"""
            self.join_min, self.join_max, self.meet_min, self.meet_max, self.disjoint = unit_cube.calc_join_and_meet(
                self.t1_uniform_min_embed, self.t1_uniform_max_embed,
                self.t2_uniform_min_embed, self.t2_uniform_max_embed)
            self.nested = unit_cube.calc_nested(self.t1_uniform_min_embed,
                                                self.t1_uniform_max_embed,
                                                self.t2_uniform_min_embed,
                                                self.t2_uniform_max_embed,
                                                self.embed_dim)
            """calculate -log(p(term2 | term1)) if overlap, surrogate function if not overlap"""
            # two surrogate function choice. lambda_batch_log_upper_bound or lambda_batch_disjoint_box
            if FLAGS.surrogate_bound:
                surrogate_func = unit_cube.lambda_batch_log_upper_bound
            else:
                surrogate_func = unit_cube.lambda_batch_disjoint_box
            """tf.where"""
            pos_tensor1 = surrogate_func(self.join_min, self.join_max,
                                         self.meet_min, self.meet_max,
                                         self.t1_uniform_min_embed,
                                         self.t1_uniform_max_embed,
                                         self.t2_uniform_min_embed,
                                         self.t2_uniform_max_embed)
            pos_tensor2 = unit_cube.lambda_batch_log_prob(
                self.t1_uniform_min_embed, self.t1_uniform_max_embed,
                self.t2_uniform_min_embed, self.t2_uniform_max_embed)
            pos_tensor1 = tf.multiply(pos_tensor1,
                                      tf.cast(self.disjoint, tf.float32))
            pos_tensor2 = tf.multiply(
                pos_tensor2, tf.cast(tf.logical_not(self.disjoint),
                                     tf.float32))
            # pos_tensor1 = tf.Print(pos_tensor1, [pos_tensor1, pos_tensor2], 'pos_tensor1')
            train_pos_prob = pos_tensor1 + pos_tensor2
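            # Unlike slicing_where, this mask-and-add form evaluates both
            # branches on every row and zeroes out the one that does not apply,
            # so each branch must stay finite even on rows it does not own.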
            """slicing where"""
            # train_pos_prob = tf_utils.slicing_where(condition=self.disjoint,
            #                                         full_input=tf.tuple([self.join_min, self.join_max, self.meet_min, self.meet_max,
            #                                                      self.t1_uniform_min_embed, self.t1_uniform_max_embed,
            #                                                      self.t2_uniform_min_embed, self.t2_uniform_max_embed]),
            #                                         true_branch=lambda x: surrogate_func(*x),
            #                                         false_branch=lambda x: unit_cube.lambda_batch_log_prob_emgerncy(*x))
            """tf.print"""
            # train_pos_prob = tf.Print(train_pos_prob, [tf.reduce_sum(tf.cast(tf.logical_and(
            #     self.disjoint, tf.logical_not(tf.cast(self.label, tf.bool))), tf.float32)), self.disjoint],
            #                           'neg disjoint value', summarize=3)
            # train_pos_prob = tf.Print(train_pos_prob, [tf.reduce_sum(tf.cast(tf.logical_and(
            #     self.disjoint, tf.cast(self.label, tf.bool)), tf.float32))],
            #                           'pos disjoint value', summarize=3)
            """calculate -log(1-p(term2 | term1)) if overlap, 0 if not overlap"""
            neg_tensor1 = unit_cube.lambda_zero(self.join_min, self.join_max,
                                                self.meet_min, self.meet_max,
                                                self.t1_uniform_min_embed,
                                                self.t1_uniform_max_embed,
                                                self.t2_uniform_min_embed,
                                                self.t2_uniform_max_embed)
            neg_tensor2 = unit_cube.lambda_batch_log_1minus_prob(
                self.join_min, self.join_max, self.meet_min, self.meet_max,
                self.t1_uniform_min_embed, self.t1_uniform_max_embed,
                self.t2_uniform_min_embed, self.t2_uniform_max_embed)
            neg_tensor1 = tf.multiply(neg_tensor1,
                                      tf.cast(self.disjoint, tf.float32))
            neg_tensor2 = tf.multiply(
                neg_tensor2, tf.cast(tf.logical_not(self.disjoint),
                                     tf.float32))
            train_neg_prob = neg_tensor1 + neg_tensor2

        else:
            """calculate box stats, join, meet and overlap condition"""
            self.join_min, self.join_max, self.meet_min, self.meet_max, self.disjoint = unit_cube.calc_join_and_meet(
                self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
                self.t2_max_embed)
            self.nested = unit_cube.calc_nested(self.t1_min_embed,
                                                self.t1_max_embed,
                                                self.t2_min_embed,
                                                self.t2_max_embed,
                                                self.embed_dim)
            """calculate -log(p(term2 | term1)) if overlap, surrogate function if not overlap"""
            # two surrogate function choice. lambda_batch_log_upper_bound or lambda_batch_disjoint_box
            if FLAGS.surrogate_bound:
                surrogate_func = unit_cube.lambda_batch_log_upper_bound
            else:
                surrogate_func = unit_cube.lambda_batch_disjoint_box
            """tf.where"""
            pos_tensor1 = 500 * surrogate_func(
                self.join_min, self.join_max, self.meet_min, self.meet_max,
                self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
                self.t2_max_embed)
            pos_tensor2 = unit_cube.lambda_batch_log_prob(
                self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
                self.t2_max_embed)
            pos_tensor1 = tf.multiply(pos_tensor1,
                                      tf.cast(self.disjoint, tf.float32))
            pos_tensor2 = tf.multiply(
                pos_tensor2, tf.cast(tf.logical_not(self.disjoint),
                                     tf.float32))
            # pos_tensor1 = tf.Print(pos_tensor1, [pos_tensor1, pos_tensor2], 'pos_tensor1')
            train_pos_prob = pos_tensor1 + pos_tensor2
            """slicing where"""
            # train_pos_prob = tf_utils.slicing_where(condition=self.disjoint,
            #                                         full_input=tf.tuple([self.join_min, self.join_max, self.meet_min, self.meet_max,
            #                                                      self.t1_min_embed, self.t1_max_embed,
            #                                                      self.t2_min_embed, self.t2_max_embed]),
            #                                         true_branch=lambda x: surrogate_func(*x),
            #                                         false_branch=lambda x: unit_cube.lambda_batch_log_prob(*x))
            """tf.print"""
            # train_pos_prob = tf.Print(train_pos_prob, [tf.reduce_sum(tf.cast(tf.logical_and(
            #     self.disjoint, tf.logical_not(tf.cast(self.label, tf.bool))), tf.float32)), self.disjoint],
            #                           'neg disjoint value', summarize=3)
            # train_pos_prob = tf.Print(train_pos_prob, [tf.reduce_sum(tf.cast(tf.logical_and(
            #     self.disjoint, tf.cast(self.label, tf.bool)), tf.float32))],
            #                           'pos disjoint value', summarize=3)
            """calculate -log(1-p(term2 | term1)) if overlap, 0 if not overlap"""
            neg_tensor1 = unit_cube.lambda_zero(self.join_min, self.join_max,
                                                self.meet_min, self.meet_max,
                                                self.t1_min_embed,
                                                self.t1_max_embed,
                                                self.t2_min_embed,
                                                self.t2_max_embed)
            neg_tensor2 = unit_cube.lambda_batch_log_1minus_prob(
                self.join_min, self.join_max, self.meet_min, self.meet_max,
                self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
                self.t2_max_embed)
            neg_tensor1 = tf.multiply(neg_tensor1,
                                      tf.cast(self.disjoint, tf.float32))
            neg_tensor2 = tf.multiply(
                neg_tensor2, tf.cast(tf.logical_not(self.disjoint),
                                     tf.float32))
            train_neg_prob = neg_tensor1 + neg_tensor2

        self.temperature_update = tf.assign_sub(self.temperature,
                                                FLAGS.decay_rate)
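        # Note: running temperature_update anneals the softplus temperature by
        # subtracting FLAGS.decay_rate from the non-trainable variable above.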
        # train_neg_prob = tf_utils.slicing_where(condition=self.disjoint,
        #                                         full_input=([self.join_min, self.join_max, self.meet_min, self.meet_max,
        #                                                      self.t1_min_embed, self.t1_max_embed,
        #                                                      self.t2_min_embed, self.t2_max_embed]),
        #                                         true_branch=lambda x: unit_cube.lambda_zero(*x),
        #                                         false_branch=lambda x: unit_cube.lambda_batch_log_1minus_prob(*x))
        """calculate negative log prob when evaluating pairs. The lower, the better"""
        # when return hierarchical error, we return the negative log probability, the lower, the probability higher
        # if two things are disjoint, we return -tf.log(1e-8).
        _, _, _, _, self.eval_disjoint = unit_cube.calc_join_and_meet(
            self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
            self.t2_max_embed)
        eval_tensor1 = unit_cube.lambda_hierarchical_error_upper(
            self.t1_min_embed, self.t1_max_embed, self.t2_min_embed,
            self.t2_max_embed)
        eval_tensor2 = unit_cube.lambda_batch_log_prob(self.t1_min_embed,
                                                       self.t1_max_embed,
                                                       self.t2_min_embed,
                                                       self.t2_max_embed)
        self.eval_prob = tf.where(self.eval_disjoint, eval_tensor1,
                                  eval_tensor2)
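        # Per pair, tf.where picks the disjoint surrogate (eval_tensor1) or the
        # exact negative log-probability (eval_tensor2).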
        # self.eval_prob = tf_utils.slicing_where(condition = self.disjoint,
        #                                         full_input = [self.join_min, self.join_max, self.meet_min, self.meet_max,
        #                                                       self.t1_min_embed, self.t1_max_embed,
        #                                                       self.t2_min_embed, self.t2_max_embed],
        #                                         true_branch = lambda x: unit_cube.lambda_hierarchical_error_upper(*x),
        #                                         false_branch = lambda x: unit_cube.lambda_batch_log_prob(*x))
        """model marg prob loss"""
        if FLAGS.w2 > 0.0:
            self.marg_prob = tf.constant(data.margina_prob)
            kl_difference = unit_cube.calc_marginal_prob(
                self.marg_prob, self.min_embed, self.delta_embed)
            kl_difference = tf.reshape(kl_difference, [-1]) / self.vocab_size
            self.marg_loss = FLAGS.w2 * (tf.reduce_sum(kl_difference))
        else:
            self.marg_loss = tf.constant(0.0)
        """model cond prob loss"""
        self.pos = FLAGS.w1 * tf.multiply(train_pos_prob, self.label)
        self.neg = FLAGS.w1 * tf.multiply(train_neg_prob, (1 - self.label))
        if FLAGS.debug:
            self.pos_disjoint = tf.logical_and(tf.cast(self.label, tf.bool),
                                               self.disjoint)
            self.pos_overlap = tf.logical_and(tf.cast(self.label, tf.bool),
                                              tf.logical_not(self.disjoint))
            self.neg_disjoint = tf.logical_and(
                tf.logical_not(tf.cast(self.label, tf.bool)), self.disjoint)
            self.neg_overlap = tf.logical_and(
                tf.logical_not(tf.cast(self.label, tf.bool)),
                tf.logical_not(self.disjoint))
            self.pos_disjoint.set_shape([None])
            self.neg_disjoint.set_shape([None])
            self.pos_overlap.set_shape([None])
            self.neg_overlap.set_shape([None])
            self.pos = tf.Print(self.pos, [
                tf.reduce_mean(tf.boolean_mask(self.pos, self.pos_disjoint)),
                tf.reduce_sum(tf.cast(self.pos_disjoint, tf.int32))
            ], 'pos disjoint loss')
            self.pos = tf.Print(self.pos, [
                tf.reduce_mean(tf.boolean_mask(self.pos, self.pos_overlap)),
                tf.reduce_sum(tf.cast(self.pos_overlap, tf.int32))
            ], 'pos overlap loss')
            self.neg = tf.Print(self.neg, [
                tf.reduce_mean(tf.boolean_mask(self.neg, self.neg_disjoint)),
                tf.reduce_sum(tf.cast(self.neg_disjoint, tf.int32))
            ], 'neg disjoint loss')
            self.neg = tf.Print(self.neg, [
                tf.reduce_mean(tf.boolean_mask(self.neg, self.neg_overlap)),
                tf.reduce_sum(tf.cast(self.neg_overlap, tf.int32))
            ], 'neg overlap loss')
            self.pos = tf.Print(self.pos, [tf.reduce_sum(self.pos)],
                                'pos loss')
            self.neg = tf.Print(self.neg, [tf.reduce_sum(self.neg)],
                                'neg loss')
            self.pos = tf.Print(self.pos, [
                tf.reduce_mean(
                    tf.exp(-tf.boolean_mask(self.pos, self.pos_overlap)))
            ], 'pos conditional prob')
            self.neg = tf.Print(self.neg, [
                tf.reduce_mean(
                    tf.exp(-tf.boolean_mask(train_pos_prob, self.neg_overlap)))
            ], 'neg conditional prob')
            self.pos = tf.Print(self.pos, [
                tf.reduce_mean(self.min_embed),
                tf.reduce_mean(self.delta_embed)
            ], 'embedding mean')
            # self.neg = tf.Print(self.neg, [tf.reduce_mean(tf.exp(tf.boolean_mask(unit_cube.batch_log_prob(self.meet_min, self.meet_max), self.neg_overlap)))], 'neg joint prob')
            # self.neg = tf.Print(self.neg, [tf.reduce_mean(tf.exp(tf.boolean_mask(unit_cube.batch_log_prob(self.t1_min_embed, self.t1_max_embed), self.neg_overlap)))], 'neg marg prob')
            self.pos_nested = tf.logical_and(tf.cast(self.label, tf.bool),
                                             self.nested)
            self.neg_nested = tf.logical_and(
                tf.logical_not(tf.cast(self.label, tf.bool)), self.nested)
            self.pos_nested.set_shape([None])
            self.neg_nested.set_shape([None])
            self.pos = tf.Print(self.pos, [
                tf.reduce_mean(tf.boolean_mask(self.pos, self.pos_nested)),
                tf.reduce_sum(tf.cast(self.pos_nested, tf.int32))
            ], 'pos nested loss')
            self.neg = tf.Print(self.neg, [
                tf.reduce_mean(tf.boolean_mask(self.neg, self.neg_nested)),
                tf.reduce_sum(tf.cast(self.neg_nested, tf.int32))
            ], 'neg nested loss')

        self.cond_loss = tf.reduce_sum(self.pos) / (self.batch_size / 2) + \
                         tf.reduce_sum(self.neg) / (self.batch_size / 2)
        # self.cond_loss = tf.Print(self.cond_loss, [tf.reduce_sum(self.pos), tf.reduce_sum(self.neg)], 'pos and neg loss')
        # self.cond_loss = tf.Print(self.cond_loss, [tf.gradients(self.cond_loss, [self.min_embed, self.delta_embed])[0], self.min_embed, self.delta_embed], 'gradient')
        """model regurlization: make box to be poe-ish"""
        self.regularization = FLAGS.r1 * tf.reduce_sum(
            tf.abs(1 - self.min_embed - self.delta_embed)) / self.vocab_size
        """model final loss"""
        # self.cond_loss = tf.Print(self.cond_loss, [self.pos, self.neg])
        self.debug = tf.constant(0.0)
        self.loss = self.cond_loss + self.marg_loss + self.regularization
        # self.loss = self.cond_loss + self.marg_loss

        if not self.freeze_grad:
            # Monitor the clipped gradient norm; each g is expected to be an
            # IndexedSlices (as produced by embedding lookups), hence g.values.
            grads = tf.gradients(self.loss, tf.trainable_variables())
            grad_norm = 0.0
            for g in grads:
                new_values = tf.clip_by_value(g.values, -0.5, 0.5)
                grad_norm += tf.reduce_sum(new_values * new_values)
            grad_norm = tf.sqrt(grad_norm)
            self.grad_norm = grad_norm
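For completeness, a common TF 1.x way to apply the same clip-by-value before the parameter update (a sketch under assumed names; the snippet above only monitors the clipped norm):

import tensorflow as tf

def build_train_op(loss, learning_rate=1e-3, epsilon=1e-8, clip=0.5):
    # Clip each gradient elementwise to [-clip, clip], mirroring the
    # tf.clip_by_value(-0.5, 0.5) used for the norm above, then apply Adam.
    def _clip(g):
        if isinstance(g, tf.IndexedSlices):  # produced by embedding lookups
            return tf.IndexedSlices(tf.clip_by_value(g.values, -clip, clip),
                                    g.indices, g.dense_shape)
        return tf.clip_by_value(g, -clip, clip)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                       epsilon=epsilon)
    grads_and_vars = optimizer.compute_gradients(loss)
    clipped = [(_clip(g), v) for g, v in grads_and_vars if g is not None]
    return optimizer.apply_gradients(clipped)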