def main(argv):
    """Build vocab maps, batchers, model and train ops, then run FLAGS.mode.

    Reads the str <-> int vocabulary files under ``FLAGS.vocab_dir``, builds
    the TensorFlow graph (model, optimizer, EMA shadow-assignment ops, saver)
    and dispatches on ``FLAGS.mode``: 'train', 'evaluate', 'export' or
    'attention'. Calls ``sys.exit(1)`` when ``FLAGS.vocab_dir`` or
    ``FLAGS.positive_train`` is empty, when the optimizer is unsupported, or
    when the mode is not recognised.

    Args:
        argv: command-line arguments handed in by the caller; not read here.
    """

    def read_vocab(file_name):
        # each line of a vocab file is '<string>\t<int id>'
        str_id_map = {}
        with open(FLAGS.vocab_dir + '/' + file_name, 'r') as f:
            for line in f:
                fields = line.split('\t')
                str_id_map[fields[0]] = int(fields[1].strip())
        return str_id_map

    def invert_map(str_id_map):
        return {i: s for s, i in str_id_map.iteritems()}

    def add_start_end(str_id_map):
        # append the boundary symbols at the next free ids if they are missing
        for marker in ('<START>', '<END>'):
            if marker not in str_id_map:
                str_id_map[marker] = len(str_id_map)

    ## TODO gross
    # when token_dim is left at 0 for a transformer / glu text encoder,
    # derive it from the embedding and position dimensions
    derives_token_dim = 'transformer' in FLAGS.text_encoder or 'glu' in FLAGS.text_encoder
    if derives_token_dim and FLAGS.token_dim == 0:
        FLAGS.token_dim = FLAGS.embed_dim - (2 * FLAGS.position_dim)
    # print flags:values in alphabetical order
    # print ('\n'.join(sorted(["%s : %s" % (str(k), str(v)) for k, v in FLAGS.__dict__['__flags'].iteritems()])))

    if FLAGS.vocab_dir == '':
        print('Error: Must supply input data generated from tsv_to_tfrecords.py')
        sys.exit(1)
    if FLAGS.positive_train == '':
        print('Error: Must supply either positive_train')
        sys.exit(1)

    # read in str <-> int vocab maps
    kb_str_id_map = read_vocab('rel.txt')
    kb_id_str_map = invert_map(kb_str_id_map)
    # the kb vocab size comes from the flag, not from the file
    kb_vocab_size = FLAGS.kb_vocab_size

    token_str_id_map = read_vocab('token.txt')
    if FLAGS.start_end:
        add_start_end(token_str_id_map)
    token_id_str_map = invert_map(token_str_id_map)
    token_vocab_size = len(token_id_str_map)

    entity_str_id_map = read_vocab('entities.txt')
    entity_id_str_map = invert_map(entity_str_id_map)
    entity_vocab_size = len(entity_id_str_map)

    ep_str_id_map = read_vocab('ep.txt')
    ep_id_str_map = invert_map(ep_str_id_map)
    ep_vocab_size = len(ep_id_str_map)

    if FLAGS.ner_train != '':
        ner_label_str_id_map = read_vocab('ner_labels.txt')
        if FLAGS.start_end:
            add_start_end(ner_label_str_id_map)
        ner_label_id_str_map = invert_map(ner_label_str_id_map)
        ner_label_vocab_size = len(ner_label_id_str_map)
    else:
        ner_label_id_str_map = {}
        ner_label_str_id_map = {}
        ner_label_vocab_size = 1
    position_vocab_size = (2 * FLAGS.max_seq)

    # optional per-relation weights: '<relation string>\t<float weight>' per line
    label_weights = None
    if FLAGS.label_weights != '':
        with open(FLAGS.label_weights, 'r') as f:
            lines = [l.strip().split('\t') for l in f]
            label_weights = {kb_str_id_map[k]: float(v) for k, v in lines}

    # optional map: entity-pair id -> set of labels (third whitespace-separated column)
    ep_kg_labels = None
    if FLAGS.kg_label_file != '':
        if FLAGS.kg_label_file.endswith('gz'):
            kg_in_file = gzip.open(FLAGS.kg_label_file, 'rb')
        else:
            kg_in_file = open(FLAGS.kg_label_file, 'r')
        # close the file even when a read fails
        try:
            lines = [l.strip().split() for l in kg_in_file.readlines()]
        finally:
            kg_in_file.close()
        ep_kg_labels = defaultdict(set)
        for l in lines:
            _ep, pid = '%s::%s' % (l[0], l[1]), l[2]
            if _ep in ep_str_id_map:
                ep_kg_labels[ep_str_id_map[_ep]].add(pid)
        print('Ep-Kg label map size %d ' % len(ep_kg_labels))

    e1_e2_ep_map = {} #{(entity_str_id_map[ep_str.split('::')[0]], entity_str_id_map[ep_str.split('::')[1]]): ep_id
                      #for ep_id, ep_str in ep_id_str_map.iteritems()}
    ep_e1_e2_map = {} #{ep: e1_e2 for e1_e2, ep in e1_e2_ep_map.iteritems()}

    # get entity <-> type maps for sampling negatives
    # NOTE(review): these two maps are not referenced again inside this function
    entity_type_map, type_entity_map = {}, defaultdict(list)
    if FLAGS.type_file != '':
        with open(FLAGS.type_file, 'r') as f:
            for l in f.readlines():
                fields = l.split('\t')
                if fields[0] in entity_str_id_map:
                    entity_type_map[entity_str_id_map[fields[0]]] = fields[1].strip().split(',')
        for entity, type_list in entity_type_map.iteritems():
            for t in type_list:
                type_entity_map[t].append(entity)
        # filter: keep types with more than one entity, then entities with more than one kept type
        type_entity_map = {k: v for k, v in type_entity_map.iteritems() if len(v) > 1}
        valid_types = set(type_entity_map)
        entity_type_map = {k: [t for t in v if t in valid_types] for k, v in entity_type_map.iteritems()}
        entity_type_map = {k: v for k, v in entity_type_map.iteritems() if len(v) > 1}

    string_int_maps = {'kb_str_id_map': kb_str_id_map, 'kb_id_str_map': kb_id_str_map,
                       'token_str_id_map': token_str_id_map, 'token_id_str_map': token_id_str_map,
                       'entity_str_id_map': entity_str_id_map, 'entity_id_str_map': entity_id_str_map,
                       'ep_str_id_map': ep_str_id_map, 'ep_id_str_map': ep_id_str_map,
                       'ner_label_str_id_map': ner_label_str_id_map, 'ner_label_id_str_map': ner_label_id_str_map,
                       'e1_e2_ep_map': e1_e2_ep_map, 'ep_e1_e2_map': ep_e1_e2_map, 'ep_kg_labels': ep_kg_labels,
                       'label_weights': label_weights}

    word_embedding_matrix = load_pretrained_embeddings(token_str_id_map, FLAGS.embeddings, FLAGS.token_dim, token_vocab_size)
    entity_embedding_matrix = load_pretrained_embeddings(entity_str_id_map, FLAGS.entity_embeddings, FLAGS.embed_dim, entity_vocab_size)

    with tf.Graph().as_default():
        # seed TensorFlow, numpy and the stdlib RNG from the same flag
        tf.set_random_seed(FLAGS.random_seed)
        np.random.seed(FLAGS.random_seed)
        random.seed(FLAGS.random_seed)

        if FLAGS.doc_filter:
            train_percent = FLAGS.train_dev_percent
            with open(FLAGS.doc_filter, 'r') as f:
                doc_filter_ids = [l.strip() for l in f]
            shuffle(doc_filter_ids)
            split_idx = int(len(doc_filter_ids) * train_percent)
            dev_ids, train_ids = set(doc_filter_ids[:split_idx]), set(doc_filter_ids[split_idx:])
            # ids in dev_ids will be filtered from dev, same for train_ids
            print('Splitting dev data %d documents for train and %d documents for dev' % (len(dev_ids), len(train_ids)))
        else:
            dev_ids, train_ids = None, None

        # have seperate batchers for positive and negative train/test
        batcher = InMemoryBatcher if FLAGS.in_memory else Batcher
        pos_dist_supervision_batcher = batcher(FLAGS.positive_dist_train, FLAGS.kb_epochs, FLAGS.max_seq, FLAGS.kb_batch) \
            if FLAGS.positive_dist_train else None
        neg_dist_supervision_batcher = batcher(FLAGS.negative_dist_train, FLAGS.kb_epochs, FLAGS.max_seq, FLAGS.kb_batch) \
            if FLAGS.negative_dist_train else None

        positive_train_batcher = batcher(FLAGS.positive_train, FLAGS.text_epochs, FLAGS.max_seq, FLAGS.text_batch)
        negative_train_batcher = batcher(FLAGS.negative_train, FLAGS.text_epochs, FLAGS.max_seq, FLAGS.text_batch)

        positive_test_batcher = InMemoryBatcher(FLAGS.positive_test, 1, FLAGS.max_seq, FLAGS.text_batch) \
            if FLAGS.positive_test else None
        negative_test_batcher = InMemoryBatcher(FLAGS.negative_test, 1, FLAGS.max_seq, FLAGS.text_batch) \
            if FLAGS.negative_test else None
        positive_test_test_batcher = InMemoryBatcher(FLAGS.positive_test_test, 1, FLAGS.max_seq, FLAGS.text_batch) \
            if FLAGS.positive_test_test else None
        negative_test_test_batcher = InMemoryBatcher(FLAGS.negative_test_test, 1, FLAGS.max_seq, FLAGS.text_batch) \
            if FLAGS.negative_test_test else None
        ner_test_batcher = NERInMemoryBatcher(FLAGS.ner_test, 1, FLAGS.max_seq, 10) if FLAGS.ner_test else None
        ner_batcher = NERBatcher(FLAGS.ner_train, FLAGS.text_epochs, FLAGS.max_seq, FLAGS.ner_batch) \
            if FLAGS.ner_train != '' else None

        # initialize model
        if 'multi' in FLAGS.model_type and 'label' in FLAGS.model_type:
            model_type = MultiLabelClassifier
        elif 'entity' in FLAGS.model_type and 'binary' in FLAGS.model_type:
            model_type = EntityBinary
        else:
            model_type = ClassifierModel
        print('Model type: %s ' % FLAGS.model_type)
        model = model_type(ep_vocab_size, entity_vocab_size, kb_vocab_size, token_vocab_size, position_vocab_size,
                           ner_label_vocab_size, word_embedding_matrix, entity_embedding_matrix, string_int_maps, FLAGS)

        # optimization
        learning_rate = tf.train.exponential_decay(FLAGS.lr, model.global_step, FLAGS.lr_decay_steps,
                                                   FLAGS.lr_decay_rate, staircase=False, name=None)
        print('Optimizer: %s' % FLAGS.optimizer)
        if FLAGS.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, epsilon=FLAGS.epsilon, beta2=FLAGS.beta2)
        elif FLAGS.optimizer == 'lazyadam':
            optimizer = tf.contrib.opt.LazyAdamOptimizer(learning_rate=learning_rate, epsilon=FLAGS.epsilon, beta2=FLAGS.beta2)
        elif FLAGS.optimizer == 'adagrad':
            optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
        elif FLAGS.optimizer == 'momentum':
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=.9)
        else:
            print('%s is not a supported optimizer type' % FLAGS.optimizer)
            sys.exit(1)

        tvars = tf.trainable_variables()
        if FLAGS.clip_norm > 0:
            if FLAGS.freeze_noise:
                # only variables whose name does NOT contain 'noise_classifier' are updated
                updated_vars = [k for k in tvars if 'noise_classifier' not in k.name]
                if not updated_vars:
                    print('Filtering noise variables removed full graph. Is this the wrong FLAGS.model_type?')
                    sys.exit(1)
                for k in updated_vars:
                    print(k.name)
                grads, _ = tf.clip_by_global_norm(tf.gradients(model.loss, updated_vars), FLAGS.clip_norm)
                train_op = optimizer.apply_gradients(zip(grads, updated_vars), global_step=model.global_step)
            else:
                grads, _ = tf.clip_by_global_norm(tf.gradients(model.loss, tvars), FLAGS.clip_norm)
                if FLAGS.noise_std > 0.0:
                    print('Adding noise to gradients with mean: 0.0 std: %0.3f' % FLAGS.noise_std)
                    # the noise is divided by sqrt(1 + global_step), so it shrinks as training proceeds
                    scale = tf.sqrt(tf.cast(1 + model.global_step, tf.float32))
                    noisy_gradients = []
                    for gradient in grads:
                        if gradient is None:
                            noisy_gradients.append(None)
                            continue
                        if isinstance(gradient, tf.IndexedSlices):
                            gradient_shape = gradient.dense_shape
                        else:
                            gradient_shape = gradient.get_shape()
                        noise = tf.truncated_normal(gradient_shape, stddev=FLAGS.noise_std) / scale
                        noisy_gradients.append(gradient + noise)
                    grads = noisy_gradients
                train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=model.global_step)
        else:
            train_op = optimizer.minimize(model.loss, global_step=model.global_step)

        # keep an exponential moving average of every trainable variable
        emma = tf.train.ExponentialMovingAverage(decay=0.9999, num_updates=model.global_step)
        emma_op = emma.apply(tvars)
        with tf.control_dependencies([train_op]):
            train_op = tf.group(emma_op)
        ner_train_op = None
        if FLAGS.ner_train != '' and FLAGS.ner_prob > 0:
            ner_grads = tf.gradients(model.ner_loss, tvars)
            # clip only for a positive clip_norm (same guard as the main train op);
            # clipping with clip_norm == 0 would scale every NER gradient to zero
            if FLAGS.clip_norm > 0:
                ner_grads, _ = tf.clip_by_global_norm(ner_grads, FLAGS.clip_norm)
            ner_train_op = optimizer.apply_gradients(zip(ner_grads, tvars), global_step=model.global_step)
            with tf.control_dependencies([ner_train_op]):
                ner_train_op = tf.group(emma_op)

        # ops that overwrite each trainable variable with its moving average
        assign_shadow_ops = []
        for t in tvars:
            v_scope_name = t.name.split(":")[0]
            v_scope = '/'.join(v_scope_name.split("/")[:-1])
            v_name = v_scope_name.split("/")[-1]
            try:
                with tf.variable_scope(v_scope, reuse=True):
                    v = tf.get_variable(v_name, t.shape)
                    shadow_v = emma.average(v)
                    assign_shadow_ops.append(v.assign(shadow_v))
            except Exception as e:
                # best effort: report the variable that could not be resolved and carry on
                print('Couldnt get %s: %s' % (v_scope_name, e))

        # restore only variables that exist in the checkpoint - needed to pre-train big models with small models
        if FLAGS.load_model != '':
            reader = tf.train.NewCheckpointReader(FLAGS.load_model)
            cp_shapes = reader.get_variable_to_shape_map()
            cp_list = set(cp_shapes)
            # if variable does not exist in checkpoint or sizes do not match, dont load
            r_vars = [k for k in tf.global_variables() if k.name.split(':')[0] in cp_list
                      and k.get_shape() == cp_shapes[k.name.split(':')[0]]]
            if len(cp_list) != len(r_vars):
                print('[Warning]: not all variables loaded from file')
                # print('\n'.join(sorted(set(cp_list)-set(r_vars))))
            saver = tf.train.Saver(var_list=r_vars)
        else:
            saver = tf.train.Saver()
        # the Supervisor gets no saver and zero save intervals; checkpoints go through `saver` above
        sv = tf.train.Supervisor(logdir=FLAGS.logdir if FLAGS.save_model != '' else None,
                                 global_step=model.global_step,
                                 saver=None,
                                 save_summaries_secs=0,
                                 save_model_secs=0)

        with sv.managed_session(FLAGS.master,
                                config=tf.ConfigProto(
                                    # log_device_placement=True,
                                    allow_soft_placement=True
                                )) as sess:
            if FLAGS.load_model != '':
                print("Deserializing model: %s" % FLAGS.load_model)
                saver.restore(sess, FLAGS.load_model)

            threads = tf.train.start_queue_runners(sess=sess)
            fb15k_eval = None
            tac_eval = None

            if FLAGS.in_memory:
                if positive_train_batcher:
                    positive_train_batcher.load_all_data(sess, doc_filter=train_ids)
                if negative_train_batcher:
                    negative_train_batcher.load_all_data(sess, doc_filter=train_ids)
            if positive_test_batcher:
                positive_test_batcher.load_all_data(sess, doc_filter=dev_ids)
            if negative_test_batcher:
                negative_test_batcher.load_all_data(sess, doc_filter=dev_ids)
            if positive_test_test_batcher:
                positive_test_test_batcher.load_all_data(sess)
            if negative_test_test_batcher:
                negative_test_test_batcher.load_all_data(sess)
            if ner_test_batcher:
                ner_test_batcher.load_all_data(sess)#, test_batches)

            if FLAGS.mode == 'train':
                save_path = '%s/%s' % (FLAGS.logdir, FLAGS.save_model) if FLAGS.save_model != '' else None
                train_model(model, pos_dist_supervision_batcher, neg_dist_supervision_batcher,
                            positive_train_batcher, negative_train_batcher, ner_batcher, sv, sess, saver,
                            train_op, ner_train_op, string_int_maps,
                            positive_test_batcher, negative_test_batcher, ner_test_batcher,
                            positive_test_test_batcher, negative_test_test_batcher,
                            tac_eval, fb15k_eval,
                            FLAGS.text_prob, FLAGS.text_weight,
                            FLAGS.log_every, FLAGS.eval_every, FLAGS.neg_samples,
                            save_path, FLAGS.kb_pretrain, FLAGS.max_decrease_epochs,
                            FLAGS.max_steps, assign_shadow_ops)

            elif FLAGS.mode == 'evaluate':
                print('Evaluating')
                results, _, threshold_map = relation_eval(sess, model, FLAGS, positive_test_batcher,
                                                          negative_test_batcher, string_int_maps)
                if positive_test_test_batcher and negative_test_test_batcher:
                    # the threshold map from the first evaluation is passed on to the test_test set
                    if FLAGS.export_file != '':
                        export_predictions(sess, model, FLAGS, positive_test_test_batcher, negative_test_test_batcher,
                                           string_int_maps, FLAGS.export_file, threshold_map=threshold_map)
                    else:
                        results, _, _ = relation_eval(sess, model, FLAGS, positive_test_test_batcher,
                                                      negative_test_test_batcher, string_int_maps, threshold_map=threshold_map)
                if ner_test_batcher:
                    ner_eval(ner_test_batcher, sess, model, FLAGS, string_int_maps)
                print(results)
            elif FLAGS.mode == 'export' and FLAGS.export_file != '':
                print('Exporting predictions')
                export_scores(sess, model, FLAGS, positive_test_test_batcher, negative_test_test_batcher,
                              string_int_maps, FLAGS.export_file)
            elif FLAGS.mode == 'attention' and FLAGS.export_file != '':
                print('Exporting attention weights')
                export_attentions(sess, model, FLAGS, positive_test_batcher, negative_test_batcher,
                                  string_int_maps, FLAGS.export_file)
            else:
                print('Error: "%s" is not a valid mode' % FLAGS.mode)
                sys.exit(1)

            sv.coord.request_stop()
            sv.coord.join(threads)
            sess.close()
def train_model(model, pos_dist_supervision_batcher, neg_dist_supervision_batcher,
                positive_train_batcher, negative_train_batcher, ner_batcher, sv, sess, saver, train_op, ner_train_op,
                string_int_maps, positive_test_batcher, negative_test_batcher, ner_test_batcher,
                positive_test_test_batcher, negative_test_test_batcher, tac_eval, fb15k_eval,
                text_prob, text_weight,
                log_every, eval_every, neg_samples, save_path,
                kb_pretrain=0, max_decrease_epochs=5, max_steps=-1, assign_shadow_ops=None):
    """Run the training loop with periodic dev evaluation and checkpointing.

    Each step performs either an NER update (with probability ``ner_prob``,
    initialised from ``FLAGS.ner_prob``) or a relation-extraction update drawn
    from the distant-supervision or text batchers. Every ``eval_every`` steps
    the shadow ops are run, the dev sets are evaluated and, when the score
    improves, the model is saved to ``save_path``.

    The loop ends when ``sv.should_stop()`` is true, when ``max_steps`` (if
    positive) is reached, or when the non-improvement counter exceeds
    ``max_decrease_epochs``.

    Args:
        model: model exposing ``loss``, ``ner_loss``, ``global_step``,
            ``loss_weight`` and ``noise_weight``.
        pos_dist_supervision_batcher, neg_dist_supervision_batcher: batchers
            used for the distant-supervision updates.
        positive_train_batcher, negative_train_batcher: batchers used for the
            text updates.
        ner_batcher: batcher whose ``next_batch(sess)`` feeds the NER updates.
        sv: supervisor; only ``should_stop()`` is called.
        sess: session used for every ``run`` call.
        saver: saver used to serialize the model on improvement.
        train_op, ner_train_op: ops run for relation / NER updates.
        string_int_maps: vocab maps forwarded to the feed-dict and eval helpers.
        positive_test_batcher, negative_test_batcher, ner_test_batcher: dev
            batchers; evaluation for a missing (falsy) one is skipped.
        positive_test_test_batcher, negative_test_test_batcher: evaluated only
            after a dev improvement, and only when both are present.
        tac_eval, fb15k_eval, neg_samples: accepted but not used here.
        text_prob: a text update is chosen when a uniform draw is <= this
            value (after the ``kb_pretrain`` warm-up steps).
        text_weight: value fed to ``model.loss_weight`` on text updates.
        log_every: log a progress line every this many steps.
        eval_every: divided by ``FLAGS.text_batch`` (minimum 1) to obtain the
            evaluation interval in steps.
        save_path: checkpoint path, or a falsy value to disable saving.
        kb_pretrain: steps below this always use a distant-supervision update.
        max_decrease_epochs: allowed number of non-improving evaluations.
        max_steps: step limit; values <= 0 disable the limit.
        assign_shadow_ops: op(s) run before each evaluation; ``None`` skips
            the run.
    """
    # The default used to be `tf.no_op()`, which is created once at definition time
    # in whatever graph is the default at import; `None` avoids fetching that foreign op.
    step = 0.
    examples = 0.
    # fixed-size ring buffers of recent losses, seeded with 1.0
    losses = [1.0]
    loss_idx = 1
    loss_avg_len = 100
    best_score = 0.0
    decrease_epochs = 0
    last_update = time.time()
    eval_every = max(1, int(eval_every/float(FLAGS.text_batch)))
    ner_losses = [1.0]
    ner_loss_idx = 1
    ner_prob = FLAGS.ner_prob
    ner_decay = 1e-4

    # ner_eval(ner_test_batcher, sess, model, token_str_id_map)
    # sys.exit(1)

    print('Starting training, eval every: %d' % eval_every)
    while not sv.should_stop() and (max_steps <= 0 or step < max_steps) and (decrease_epochs <= max_decrease_epochs):
        if FLAGS.anneal_ner and step > 5000:
            # NOTE(review): this rescales the already-decayed ner_prob on every step, so the
            # decay compounds - confirm that is intended rather than decaying from FLAGS.ner_prob
            ner_prob = ner_prob * 1/(1 + ner_decay * step)
        do_ner_update = random.uniform(0, 1) <= ner_prob

        # eval / serialize
        if step > 0 and step % eval_every == 0:
            if assign_shadow_ops is not None:
                sess.run(assign_shadow_ops)
            # reset per evaluation so a missing dev batcher cannot raise NameError below
            avg_p, ner_f1, threshold_map = None, None, None
            if positive_test_batcher:
                avg_p, _, threshold_map = relation_eval(sess, model, FLAGS,
                                                        positive_test_batcher, negative_test_batcher,
                                                        string_int_maps, message='Dev')
            if ner_test_batcher and FLAGS.ner_prob > 0:
                ner_f1, _ = ner_eval(ner_test_batcher, sess, model, FLAGS, string_int_maps)

            keep_score = avg_p if FLAGS.ner_prob < 1.0 else ner_f1
            if keep_score is None:
                # nothing to compare against best_score: neither save nor count a decrease
                print('\nNo dev score available, skipping model selection\n')
            elif keep_score > best_score:
                decrease_epochs = 0
                best_score = keep_score
                if save_path:
                    saved_path = saver.save(sess, save_path)
                    print("Serialized model: %s" % saved_path)
                if positive_test_batcher:
                    # if FLAGS.analyze_errors > 0: analyze_errors(flat_scores, string_int_maps)
                    if positive_test_test_batcher and negative_test_test_batcher:
                        print('Evaluating Test Test')
                        relation_eval(sess, model, FLAGS,
                                      positive_test_test_batcher, negative_test_test_batcher,
                                      string_int_maps, message='Test', threshold_map=threshold_map)
            # if model doesnt improve after max_decrease_epochs, stop training
            elif FLAGS.ner_prob == 1.0 or not do_ner_update:
                decrease_epochs += 1
                print('\nEval decreased for %d epochs out of %d max epochs. Best: %2.2f\n'
                      % (decrease_epochs, max_decrease_epochs, best_score))

        if do_ner_update:
            # ner update
            ner_batch = ner_batcher.next_batch(sess)
            feed_dict, _ = ner_feed_dict(ner_batch, model, FLAGS, string_int_maps=string_int_maps)
            _, _, ner_loss = sess.run([ner_train_op, model.global_step, model.ner_loss], feed_dict=feed_dict)
            if len(ner_losses) < loss_avg_len:
                ner_losses.append(np.mean(ner_loss))
                ner_loss_idx += 1
            else:
                ner_loss_idx = 0 if ner_loss_idx >= (loss_avg_len - 1) else ner_loss_idx + 1
                ner_losses[ner_loss_idx] = np.mean(ner_loss)
        else:
            # relex update
            use_dist_supervision = step < kb_pretrain or random.uniform(0, 1) > text_prob
            use_positive = FLAGS.pos_prob >= random.uniform(0, 1)
            if use_dist_supervision:
                chosen_batcher = pos_dist_supervision_batcher if use_positive else neg_dist_supervision_batcher
            else:
                chosen_batcher = positive_train_batcher if use_positive else negative_train_batcher
            feed_dict, batch_size, _ = batch_feed_dict(chosen_batcher, sess, model,
                                                       FLAGS, string_int_maps=string_int_maps)
            if not use_dist_supervision:
                # text update: use the weight handed to this function
                # (the existing caller passes FLAGS.text_weight)
                feed_dict[model.loss_weight] = text_weight

            feed_dict[model.noise_weight] = FLAGS.variance_min

            _, _, loss = sess.run([train_op, model.global_step, model.loss], feed_dict=feed_dict)
            examples += batch_size
            loss /= batch_size

            # update loss moving avg
            if len(losses) < loss_avg_len:
                losses.append(loss)
                loss_idx += 1
            else:
                loss_idx = 0 if loss_idx >= (loss_avg_len - 1) else loss_idx + 1
                losses[loss_idx] = loss

        # log
        if step % log_every == 0:
            # guard against a zero interval on coarse clocks
            elapsed = max(time.time() - last_update, 1e-8)
            steps_per_sec = log_every / elapsed
            examples_per_sec = examples / elapsed
            examples = 0.

            sys.stdout.write('\rstep: %d \t avg loss: %.4f \t ner loss: %.4f'
                             '\t steps/sec: %.4f \t text examples/sec: %5.2f' %
                             (step, float(np.mean(losses)), float(np.mean(ner_losses)), steps_per_sec, examples_per_sec))
            sys.stdout.flush()
            last_update = time.time()
        step += 1

    print('\n Done training')
    if best_score > 0.0:
        print('Best Score: %2.2f' % best_score)
# Example #3
# 0
def main(argv):
    """Load the vocab maps and a MultiLabelClassifier, then evaluate on the dev set.

    Prints all flags, reads the str <-> int vocabulary files under
    ``FLAGS.vocab_dir``, builds the model graph, optionally restores
    ``FLAGS.load_model`` and runs ``relation_eval`` on the positive/negative
    test batchers. Calls ``sys.exit(1)`` when ``FLAGS.vocab_dir`` is empty.

    Args:
        argv: command-line arguments handed in by the caller; not read here.
    """

    def read_vocab(file_name):
        # each line of a vocab file is '<string>\t<int id>'
        str_id_map = {}
        with open(FLAGS.vocab_dir + '/' + file_name, 'r') as f:
            for line in f:
                fields = line.split('\t')
                str_id_map[fields[0]] = int(fields[1].strip())
        return str_id_map

    def invert_map(str_id_map):
        return {i: s for s, i in str_id_map.iteritems()}

    # when token_dim is left at 0 for a transformer / glu text encoder,
    # derive it from the embedding and position dimensions
    derives_token_dim = 'transformer' in FLAGS.text_encoder or 'glu' in FLAGS.text_encoder
    if derives_token_dim and FLAGS.token_dim == 0:
        FLAGS.token_dim = FLAGS.embed_dim - (2 * FLAGS.position_dim)

    # print flags:values in alphabetical order
    print('\n'.join(
        sorted([
            "%s : %s" % (str(k), str(v))
            for k, v in FLAGS.__dict__['__flags'].iteritems()
        ])))

    if FLAGS.vocab_dir == '':
        print(
            'Error: Must supply input data generated from tsv_to_tfrecords.py')
        sys.exit(1)

    position_vocab_size = (2 * FLAGS.max_seq)

    # read in str <-> int vocab maps
    kb_str_id_map = read_vocab('rel.txt')
    kb_id_str_map = invert_map(kb_str_id_map)
    # the kb vocab size comes from the flag, not from the file
    kb_vocab_size = FLAGS.kb_vocab_size

    token_str_id_map = read_vocab('token.txt')
    if FLAGS.start_end:
        # append the boundary symbols at the next free ids if they are missing
        for marker in ('<START>', '<END>'):
            if marker not in token_str_id_map:
                token_str_id_map[marker] = len(token_str_id_map)
    token_id_str_map = invert_map(token_str_id_map)
    token_vocab_size = len(token_id_str_map)

    entity_str_id_map = read_vocab('entities.txt')
    entity_id_str_map = invert_map(entity_str_id_map)
    entity_vocab_size = len(entity_id_str_map)

    ep_str_id_map = read_vocab('ep.txt')
    ep_id_str_map = invert_map(ep_str_id_map)
    ep_vocab_size = len(ep_id_str_map)

    # optional map: entity-pair id -> set of labels (third whitespace-separated column)
    ep_kg_labels = None
    if FLAGS.kg_label_file != '':
        if FLAGS.kg_label_file.endswith('gz'):
            kg_in_file = gzip.open(FLAGS.kg_label_file, 'rb')
        else:
            kg_in_file = open(FLAGS.kg_label_file, 'r')
        # close the file even when a read fails
        try:
            lines = [l.strip().split() for l in kg_in_file.readlines()]
        finally:
            kg_in_file.close()
        ep_kg_labels = defaultdict(set)
        for l in lines:
            _ep, pid = '%s::%s' % (l[0], l[1]), l[2]
            if _ep in ep_str_id_map:
                ep_kg_labels[ep_str_id_map[_ep]].add(pid)
        print('Ep-Kg label map size %d ' % len(ep_kg_labels))

    # optional per-relation weights: '<relation string>\t<float weight>' per line
    label_weights = None
    if FLAGS.label_weights != '':
        with open(FLAGS.label_weights, 'r') as f:
            lines = [l.strip().split('\t') for l in f]
            label_weights = {kb_str_id_map[k]: float(v) for k, v in lines}

    model_type = MultiLabelClassifier
    # no NER vocabulary is loaded in this entry point
    ner_label_id_str_map = {}
    ner_label_str_id_map = {}
    ner_label_vocab_size = 1
    e1_e2_ep_map = {}  # {(entity_str_id_map[ep_str.split('::')[0]], entity_str_id_map[ep_str.split('::')[1]]): ep_id
    # for ep_id, ep_str in ep_id_str_map.iteritems()}
    ep_e1_e2_map = {}  # {ep: e1_e2 for e1_e2, ep in e1_e2_ep_map.iteritems()}

    word_embedding_matrix = load_pretrained_embeddings(token_str_id_map,
                                                       FLAGS.embeddings,
                                                       FLAGS.token_dim,
                                                       token_vocab_size)
    entity_embedding_matrix = load_pretrained_embeddings(
        entity_str_id_map, FLAGS.entity_embeddings, FLAGS.embed_dim,
        entity_vocab_size)

    string_int_maps = {
        'kb_str_id_map': kb_str_id_map,
        'kb_id_str_map': kb_id_str_map,
        'token_str_id_map': token_str_id_map,
        'token_id_str_map': token_id_str_map,
        'entity_str_id_map': entity_str_id_map,
        'entity_id_str_map': entity_id_str_map,
        'ep_str_id_map': ep_str_id_map,
        'ep_id_str_map': ep_id_str_map,
        'ner_label_str_id_map': ner_label_str_id_map,
        'ner_label_id_str_map': ner_label_id_str_map,
        'e1_e2_ep_map': e1_e2_ep_map,
        'ep_e1_e2_map': ep_e1_e2_map,
        'ep_kg_labels': ep_kg_labels,
        'label_weights': label_weights
    }

    with tf.Graph().as_default():
        # seed TensorFlow, numpy and the stdlib RNG from the same flag
        tf.set_random_seed(FLAGS.random_seed)
        np.random.seed(FLAGS.random_seed)
        random.seed(FLAGS.random_seed)

        if FLAGS.doc_filter:
            train_percent = FLAGS.train_dev_percent
            with open(FLAGS.doc_filter, 'r') as f:
                doc_filter_ids = [l.strip() for l in f]
            shuffle(doc_filter_ids)
            split_idx = int(len(doc_filter_ids) * train_percent)
            dev_ids, train_ids = set(doc_filter_ids[:split_idx]), set(
                doc_filter_ids[split_idx:])
            # ids in dev_ids will be filtered from dev, same for train_ids
            print(
                'Splitting dev data %d documents for train and %d documents for dev'
                % (len(dev_ids), len(train_ids)))
        else:
            dev_ids, train_ids = None, None

        # one batcher each for the positive and negative dev examples
        # (each used to be constructed twice; the first instance was discarded)
        positive_test_batcher = InMemoryBatcher(FLAGS.positive_test, 1, FLAGS.max_seq, FLAGS.text_batch) \
            if FLAGS.positive_test else None
        negative_test_batcher = InMemoryBatcher(FLAGS.negative_test, 1, FLAGS.max_seq, FLAGS.text_batch) \
            if FLAGS.negative_test else None

        model = model_type(ep_vocab_size, entity_vocab_size, kb_vocab_size,
                           token_vocab_size, position_vocab_size,
                           ner_label_vocab_size, word_embedding_matrix,
                           entity_embedding_matrix, string_int_maps, FLAGS)

        # restore only variables that exist in the checkpoint - needed to pre-train big models with small models
        if FLAGS.load_model != '':
            reader = tf.train.NewCheckpointReader(FLAGS.load_model)
            cp_shapes = reader.get_variable_to_shape_map()
            cp_list = set(cp_shapes)
            # if variable does not exist in checkpoint or sizes do not match, dont load
            r_vars = [
                k for k in tf.global_variables()
                if k.name.split(':')[0] in cp_list
                and k.get_shape() == cp_shapes[k.name.split(':')[0]]
            ]
            if len(cp_list) != len(r_vars):
                print('[Warning]: not all variables loaded from file')
                # print('\n'.join(sorted(set(cp_list)-set(r_vars))))
            saver = tf.train.Saver(var_list=r_vars)
        else:
            saver = tf.train.Saver()
        # the Supervisor gets no saver and zero save intervals; restoring goes through `saver` above
        sv = tf.train.Supervisor(
            logdir=FLAGS.logdir if FLAGS.save_model != '' else None,
            global_step=model.global_step,
            saver=None,
            save_summaries_secs=0,
            save_model_secs=0,
        )

        with sv.managed_session(
                FLAGS.master,
                config=tf.ConfigProto(
                    # log_device_placement=True,
                    allow_soft_placement=True)) as sess:

            if positive_test_batcher:
                positive_test_batcher.load_all_data(sess, doc_filter=dev_ids)
            if negative_test_batcher:
                negative_test_batcher.load_all_data(sess, doc_filter=dev_ids)

            if FLAGS.load_model != '':
                print("Deserializing model: %s" % FLAGS.load_model)
                saver.restore(sess, FLAGS.load_model)

            # the leftover pdb breakpoint was removed so non-interactive runs do not hang
            relation_eval(sess, model, FLAGS, positive_test_batcher,
                          negative_test_batcher, string_int_maps)