Example #1
def create_input_pipeline(el_mode, model_folder, filenames):
    tf.reset_default_graph()
    folder = config.base_folder+"data/tfrecords/" + args.experiment_name + ("/test/" if el_mode else "/train/")
    datasets = []
    for file in filenames:
        datasets.append(reader.test_input_pipeline([folder+file], args))

    input_handle_ph = tf.placeholder(tf.string, shape=[], name="input_handle_ph")
    iterator = tf.data.Iterator.from_string_handle(
        input_handle_ph, datasets[0].output_types, datasets[0].output_shapes)
    next_element = iterator.get_next()

    train_args = load_train_args(args.output_folder, "ensemble_eval")

    print("loading Model:", model_folder)
    #train_args.evaluation_script = True
    train_args.entity_extension = args.entity_extension
    model = Model(train_args, next_element)
    model.build()
    #print("model train_args:", model.args)
    #print("model checkpoint_folder:", model.args.checkpoints_folder)
    model.input_handle_ph = input_handle_ph
    model.restore_session("el" if el_mode else "ed")

    #iterators, handles = from_datasets_to_iterators_and_handles(model.sess, datasets)
    iterators = []
    handles = []
    for dataset in datasets:
        iterator = dataset.make_one_shot_iterator()  # a one-shot iterator fits better here than an initializable one
        iterators.append(iterator)
        handles.append(model.sess.run(iterator.string_handle()))
    return model, handles
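
# Usage sketch (an assumption, not part of the original code): the string handles
# returned above are fed into input_handle_ph so every dataset runs through the
# same restored model. The folder and file names are hypothetical; fetching
# model.final_scores follows the later snippets.
model, handles = create_input_pipeline(el_mode=False,
                                       model_folder="base_model",
                                       filenames=["aida_dev.tfrecords"])
batch_scores = model.sess.run(model.final_scores,
                              feed_dict={model.input_handle_ph: handles[0]})
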
def evaluate():

    ed_datasets, ed_names = train.create_el_ed_pipelines(
        gmonly_flag=True, filenames=args.ed_datasets, args=args)
    el_datasets, el_names = train.create_el_ed_pipelines(
        gmonly_flag=False, filenames=args.el_datasets, args=args)

    input_handle_ph = tf.placeholder(tf.string,
                                     shape=[],
                                     name="input_handle_ph")
    sample_dataset = ed_datasets[0] if ed_datasets else el_datasets[0]
    iterator = tf.data.Iterator.from_string_handle(
        input_handle_ph, sample_dataset.output_types,
        sample_dataset.output_shapes)
    next_element = iterator.get_next()

    model = Model(train_args, next_element)
    model.build()
    model.input_handle_ph = input_handle_ph  # kept on the model for convenience so it can be accessed from everywhere
    print(tf.global_variables())
    if args.p_e_m_algorithm:
        model.final_scores = model.cand_entities_scores

    def ed_el_dataset_handles(sess, datasets):
        test_iterators = []
        test_handles = []
        for dataset in datasets:
            test_iterator = dataset.make_initializable_iterator()
            test_iterators.append(test_iterator)
            test_handles.append(sess.run(test_iterator.string_handle()))
        return test_iterators, test_handles

    for el_mode, datasets, names in zip([False, True],
                                        [ed_datasets, el_datasets],
                                        [ed_names, el_names]):
        if not names:
            continue
        model.restore_session("el" if el_mode else "ed")
        #print_variables_values(model)

        with model.sess as sess:
            print("Evaluating {} datasets".format("EL" if el_mode else "ED"))
            iterators, handles = ed_el_dataset_handles(sess, datasets)
            compute_ed_el_scores(model,
                                 handles,
                                 names,
                                 iterators,
                                 el_mode=el_mode)
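
# For context, a hedged sketch of what compute_ed_el_scores (called above but not
# shown) is assumed to do for each dataset: re-initialize its iterator, feed that
# dataset's string handle until the iterator is exhausted, and return one score
# per dataset. The metric accumulation is omitted and the fetched tensor is an
# assumption.
def compute_ed_el_scores_sketch(model, handles, names, iterators, el_mode):
    scores = []
    for handle, name, test_iterator in zip(handles, names, iterators):
        # initializable iterators must be (re)initialized before every pass
        model.sess.run(test_iterator.initializer)
        try:
            while True:
                model.sess.run(model.final_scores,
                               feed_dict={model.input_handle_ph: handle})
        except tf.errors.OutOfRangeError:
            pass
        score = 0.0  # placeholder: the real function computes e.g. micro-F1 here
        print("{} score for {}: {}".format("EL" if el_mode else "ED", name, score))
        scores.append(score)
    return scores
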
Example #3
def train():
    training_dataset = create_training_pipelines(args)

    ed_datasets, ed_names = create_el_ed_pipelines(gmonly_flag=True,
                                                   filenames=args.ed_datasets,
                                                   args=args)
    el_datasets, el_names = create_el_ed_pipelines(gmonly_flag=False,
                                                   filenames=args.el_datasets,
                                                   args=args)

    input_handle_ph = tf.placeholder(tf.string,
                                     shape=[],
                                     name="input_handle_ph")
    iterator = tf.data.Iterator.from_string_handle(
        input_handle_ph, training_dataset.output_types,
        training_dataset.output_shapes)
    next_element = iterator.get_next()
    #print(next_element)

    if args.ablations:
        from model.model_ablations import Model
    else:
        from model.model import Model
    model = Model(args, next_element)
    model.build()
    model.input_handle_ph = input_handle_ph  # kept on the model for convenience so it can be accessed from everywhere
    #print(tf.global_variables())

    tf_writers = tensorboard_writers(model.sess.graph)
    model.tf_writers = tf_writers  # kept on the model for convenient access

    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
    # and used to feed the `handle` placeholder.
    with model.sess as sess:

        def ed_el_dataset_handles(datasets):
            test_iterators = []
            test_handles = []
            for dataset in datasets:
                test_iterator = dataset.make_initializable_iterator()
                test_iterators.append(test_iterator)
                test_handles.append(sess.run(test_iterator.string_handle()))
            return test_iterators, test_handles

        training_iterator = training_dataset.make_one_shot_iterator()
        training_handle = sess.run(training_iterator.string_handle())

        ed_iterators, ed_handles = ed_el_dataset_handles(ed_datasets)
        el_iterators, el_handles = ed_el_dataset_handles(el_datasets)

        # Loop forever, alternating between training and validation.
        best_ed_score = 0
        best_el_score = 0
        termination_ed_score = 0
        termination_el_score = 0
        nepoch_no_imprv = 0  # for early stopping
        train_step = 0
        while True:
            total_train_loss = 0
            # for _ in range(args.steps_before_evaluation):          # for training based on training steps
            wall_start = time.time()
            while ((time.time() - wall_start) / 60) <= args.evaluation_minutes:
                train_step += 1
                if args.ffnn_l2maxnorm:
                    sess.run(model.ffnn_l2normalization_op_list)
                _, loss = sess.run(
                    [model.train_op, model.loss],
                    feed_dict={
                        input_handle_ph: training_handle,
                        model.dropout: args.dropout,
                        model.lr: model.args.lr
                    })
                total_train_loss += loss

            args.eval_cnt += 1
            summary = tf.Summary(value=[
                tf.Summary.Value(tag="total_train_loss",
                                 simple_value=total_train_loss)
            ])
            tf_writers["train"].add_summary(summary, args.eval_cnt)

            print("args.eval_cnt = ", args.eval_cnt)
            summary = sess.run(model.merged_summary_op)
            tf_writers["train"].add_summary(summary, args.eval_cnt)

            wall_start = time.time()
            comparison_ed_score = comparison_el_score = -0.1
            if ed_names:
                print("Evaluating ED datasets")
                ed_scores = compute_ed_el_scores(model,
                                                 ed_handles,
                                                 ed_names,
                                                 ed_iterators,
                                                 el_mode=False)
                comparison_ed_score = np.mean(
                    np.array(ed_scores)[args.ed_val_datasets])
            if el_names:
                print("Evaluating EL datasets")
                el_scores = compute_ed_el_scores(model,
                                                 el_handles,
                                                 el_names,
                                                 el_iterators,
                                                 el_mode=True)
                comparison_el_score = np.mean(
                    np.array(el_scores)[args.el_val_datasets])
            print("Evaluation duration in minutes: ",
                  (time.time() - wall_start) / 60)

            #comparison_ed_score = (ed_scores[1] + ed_scores[4]) / 2   # aida_dev + acquaint
            #comparison_score = ed_scores[1]  # aida_dev
            if model.args.lr_decay > 0:
                model.args.lr *= model.args.lr_decay  # decay learning rate
            text = ""
            best_ed_flag = False
            best_el_flag = False
            # require a margin of 0.1, since e.g. going from 75.2 to 75.3 micro-F1
            # on aida_dev is not a significant improvement
            if comparison_ed_score >= best_ed_score + 0.1:  # args.improvement_threshold:
                text = "- new best ED score!" + " prev_best= " + str(best_ed_score) +\
                       " new_best= " + str(comparison_ed_score)
                best_ed_flag = True
                best_ed_score = comparison_ed_score
            if comparison_el_score >= best_el_score + 0.1:  #args.improvement_threshold:
                text += "- new best EL score!" + " prev_best= " + str(best_el_score) +\
                       " new_best= " + str(comparison_el_score)
                best_el_flag = True
                best_el_score = comparison_el_score
            if best_ed_flag or best_el_flag:  # keep checkpoint
                print(text)
                if not args.nocheckpoints:
                    model.save_session(args.eval_cnt, best_ed_flag,
                                       best_el_flag)
            # check for termination now.
            if comparison_ed_score >= termination_ed_score + args.improvement_threshold\
                    or comparison_el_score >= termination_el_score + args.improvement_threshold:
                print("significant improvement. reset termination counter")
                termination_ed_score = comparison_ed_score
                termination_el_score = comparison_el_score
                nepoch_no_imprv = 0
            else:
                nepoch_no_imprv += 1
                if nepoch_no_imprv >= args.nepoch_no_imprv:
                    print("- early stopping {} epochs without "
                          "improvement".format(nepoch_no_imprv))
                    terminate()
                    break
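
# tensorboard_writers() is called above but not shown; a minimal sketch of what it
# is assumed to return: one tf.summary.FileWriter per phase, keyed by name, with
# the graph attached to the training writer. The log directory is an assumption
# and tensorflow is assumed imported as tf, as in the snippets above.
import os

def tensorboard_writers_sketch(graph, logdir="tensorboard_logs"):
    return {
        "train": tf.summary.FileWriter(os.path.join(logdir, "train"), graph),
        "ed": tf.summary.FileWriter(os.path.join(logdir, "ed")),
        "el": tf.summary.FileWriter(os.path.join(logdir, "el")),
    }
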
Example #4
    def __init__(self, train_args, args):
        self.args = args
        # input pipeline
        self.streaming_samples = StreamingSamples()
        ds = tf.data.Dataset.from_generator(
            self.streaming_samples.gen,
            (tf.int64, tf.int64, tf.int64, tf.int64,   # words, words_len, chars, chars_len
             tf.int64, tf.int64, tf.int64,             # begin_span, end_span, span_len
             tf.int64, tf.float32, tf.int64),          # cand_entities, cand_entities_scores, cand_entities_len
            (tf.TensorShape([None]), tf.TensorShape([]),
             tf.TensorShape([None, None]), tf.TensorShape([None]),
             tf.TensorShape([None]), tf.TensorShape([None]), tf.TensorShape([]),
             tf.TensorShape([None, None]), tf.TensorShape([None, None]),
             tf.TensorShape([None])))
        next_element = ds.make_one_shot_iterator().get_next()
        # batch size = 1: expand dims so the tensors match the training input, which has a batch dimension
        next_element = [tf.expand_dims(t, 0) for t in next_element]
        next_element = [
            None, *next_element[:-1], None, next_element[-1], None, None, None,
            None
        ]

        # restore model
        print("loading Model:", train_args.output_folder)
        model = Model(train_args, next_element)
        model.build()
        checkpoint_path = model.restore_session("el" if args.el_mode else "ed")
        self.model = model
        if args.hardcoded_thr:
            self.thr = args.hardcoded_thr
            print("threshold used:", self.thr)
        else:
            # optimal threshold recovery from log files:
            # based on the selected checkpoint, look up the threshold in the log file (otherwise recompute it)
            self.thr = retrieve_optimal_threshold_from_logfile(
                train_args.output_folder, checkpoint_path, args.el_mode)
            print("optimal threshold selected = ", self.thr)

        if args.running_mode == "el_mode":
            args.el_mode = True
        elif args.running_mode == "ed_mode":
            args.el_mode = False

        # convert text to tensors for the NN
        with open(args.experiment_folder + "word_char_maps.pickle",
                  'rb') as handle:
            self.word2id, _, self.char2id, _, _, _ = pickle.load(handle)

        self.wikiid2nnid = load_wikiid2nnid(
            extension_name=args.entity_extension)
        self.nnid2wikiid = reverse_dict(self.wikiid2nnid, unique_values=True)
        _, self.wiki_id_name_map = load_wiki_name_id_map()

        with open(args.experiment_folder + "prepro_args.pickle",
                  'rb') as handle:
            self.prepro_args = pickle.load(handle)
            if args.lowercase_spans_pem:
                self.prepro_args.lowercase_p_e_m = True
                self.prepro_args.lowercase_spans = True
        print("prepro_args:", self.prepro_args)
        self.prepro_args.persons_coreference = args.persons_coreference
        self.prepro_args.persons_coreference_merge = args.persons_coreference_merge
        self.fetchFilteredCoreferencedCandEntities = FetchFilteredCoreferencedCandEntities(
            self.prepro_args)
        prepro_util.args = self.prepro_args

        self.special_tokenized_words = {"``", '"', "''"}
        self.special_words_assertion_errors = 0
        self.gm_idx_errors = 0
        if self.args.el_with_stanfordner_and_our_ed:
            from nltk.tag import StanfordNERTagger
            self.st = StanfordNERTagger(
                '../data/stanford_core_nlp/stanford-ner-2018-02-27/classifiers/english.all.3class.distsim.crf.ser.gz',
                '../data/stanford_core_nlp/stanford-ner-2018-02-27/stanford-ner.jar',
                encoding='utf-8')
        self.from_myspans_to_given_spans_map_errors = 0
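
# StreamingSamples is consumed above through tf.data.Dataset.from_generator but
# not shown; a minimal sketch, assuming it hands preprocessed samples to the
# generator one at a time through a small queue (the queue-based handoff and the
# feed() method name are assumptions, not the project's exact implementation).
import queue

class StreamingSamplesSketch(object):
    def __init__(self):
        self._buffer = queue.Queue(maxsize=1)

    def feed(self, sample):
        # sample: a tuple matching the dtypes and shapes declared in from_generator
        self._buffer.put(sample)

    def gen(self):
        # infinite generator: from_generator keeps pulling samples as they arrive
        while True:
            yield self._buffer.get()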