Example #1
    def __init__(self,
                 config,
                 sess,
                 checkpoint_saver,
                 best_model_saver,
                 restore_if_possible=True):
        self.config = config
        self.checkpoint_saver = checkpoint_saver
        self.best_model_saver = best_model_saver

        tf.gfile.MakeDirs(config.checkpoints_dir)
        if restore_if_possible and tf.gfile.Exists(config.progress):
            history, current_file, current_line = utils.load_pickle(
                config.progress, memoized=False)
            self.history = history
            self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(
                config, current_file, current_line)
            utils.log(
                "Continuing from global step",
                dict(self.history[-1])["step"],
                "(lm1b file {:}, line {:})".format(current_file, current_line))
            self.checkpoint_saver.restore(
                sess, tf.train.latest_checkpoint(self.config.checkpoints_dir))
        else:
            utils.log("No previous checkpoint found - starting from scratch")
            self.history = []
            self.unlabeled_data_reader = (
                unlabeled_data.UnlabeledDataReader(config))
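The restore branch above expects config.progress to hold a pickled (history, current_file, current_line) triple. Below is a minimal sketch of the matching write() method, which Example #5 calls as progress.write(sess, global_step); config.checkpoint and the reader's current_file/current_line attributes are assumptions not shown in this listing.

    def write(self, sess, global_step):
        # Hypothetical sketch: save the weights, then pickle enough state
        # to resume (evaluation history plus the unlabeled reader position).
        self.checkpoint_saver.save(sess, self.config.checkpoint,
                                   global_step=global_step)
        utils.write_pickle(
            [self.history, self.unlabeled_data_reader.current_file,
             self.unlabeled_data_reader.current_line],
            self.config.progress)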
Example #2
File: cvt.py Project: bluesea0/ditk
def main(mode='train',
         model_name='chunking_model',
         data_dir='/mini_data',
         size='mini',
         gdrive_mounted='f'):
    utils.heading('SETUP')
    if size == 'mini':
        args_dict = {
            'warm_up_steps': 50.0,
            'train_batch_size': 10,
            'test_batch_size': 10,
            'data_dir': data_dir,
            'mode': mode,
            'model_name': model_name,
            'eval_dev_every': 75,
            'eval_train_every': 150,
            'save_model_every': 10
        }
        config = configure.Config(**args_dict)
    else:
        config = configure.Config(data_dir=data_dir,
                                  mode=mode,
                                  model_name=model_name)
    config.write()
    with tf.Graph().as_default() as graph:
        model_trainer = trainer.Trainer(config)
        summary_writer = tf.summary.FileWriter(config.summaries_dir)
        checkpoints_saver = tf.train.Saver(max_to_keep=1)
        best_model_saver = tf.train.Saver(max_to_keep=1)
        init_op = tf.global_variables_initializer()
        graph.finalize()
        with tf.Session() as sess:
            sess.run(init_op)
            progress = training_progress.TrainingProgress(
                config, sess, checkpoints_saver, best_model_saver,
                config.mode == 'train')
            utils.log()
            if config.mode == 'train':
                utils.heading('START TRAINING ({:})'.format(config.model_name))
                model_trainer.train(sess, progress, summary_writer)
            elif config.mode == 'eval':
                utils.heading('RUN EVALUATION ({:})'.format(config.model_name))
                progress.best_model_saver.restore(
                    sess, tf.train.latest_checkpoint(config.checkpoints_dir))
                op_preds = model_trainer.evaluate_all_tasks(
                    sess, summary_writer, None)
                return op_preds
            else:
                raise ValueError('Mode must be "train" or "eval"')
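A hypothetical invocation of main() above: 'train' mode runs the loop from Example #5 and returns nothing, while 'eval' mode restores the best checkpoint and returns the predictions.

main(mode='train', model_name='chunking_model', data_dir='/mini_data',
     size='mini')
preds = main(mode='eval', model_name='chunking_model', data_dir='/mini_data',
             size='mini')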
Example #3
    def _get_examples(self, split):
        word_vocab = embeddings.get_word_vocab(self._config)
        char_vocab = embeddings.get_char_vocab()
        examples = [
            TaggingExample(self._config, self._is_token_level, words, tags,
                           word_vocab, char_vocab, self.label_mapping,
                           self._task_name)
            for words, tags in self.get_labeled_sentences(split)
        ]
        if self._config.train_set_percent < 100:
            utils.log('using reduced train set ({:}%)'.format(
                self._config.train_set_percent))
            random.shuffle(examples)
            examples = examples[:int(
                len(examples) * self._config.train_set_percent / 100.0)]
        return examples
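For illustration (hypothetical numbers): with train_set_percent = 10 and 1000 labeled sentences, the slice above keeps the first int(1000 * 10 / 100.0) = 100 of the shuffled examples.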
Example #4
    def save_if_best_dev_model(self, sess, global_step):
        best_avg_score = 0
        for i, results in enumerate(self.history):
            if any("train" in metric for metric, value in results):
                continue
            total, count = 0, 0
            for metric, value in results:
                if "f1" in metric or "las" in metric or "accuracy" in metric:
                    total += value
                    count += 1
            if count == 0:  # guard against evals with no scored metrics
                continue
            avg_score = total / count
            if avg_score >= best_avg_score:
                best_avg_score = avg_score
                if i == len(self.history) - 1:
                    utils.log("New best model! Saving...")
                    self.best_model_saver.save(
                        sess,
                        self.config.best_model_checkpoint,
                        global_step=global_step)
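For illustration, each history entry iterated above is a list of (metric, value) pairs from one evaluation pass; entries containing any "train" metric are skipped, so only dev scores compete for the best average. A hypothetical example (metric names and values made up):

history = [
    [('chunk_train_f1', 97.5)],                              # train eval: skipped
    [('chunk_dev_f1', 91.2), ('chunk_dev_accuracy', 95.0)],  # dev avg = 93.1
]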
Example #5
    def train(self, sess, progress, summary_writer):
        heading = lambda s: utils.heading(s, '(' + self._config.model_name + ')')
        trained_on_sentences = 0
        start_time = time.time()
        unsupervised_loss_total, unsupervised_loss_count = 0, 0
        supervised_loss_total, supervised_loss_count = 0, 0
        for mb in self._get_training_mbs(progress.unlabeled_data_reader):
            if mb.task_name != 'unlabeled':
                loss = self._model.train_labeled(sess, mb)
                supervised_loss_total += loss
                supervised_loss_count += 1

            if mb.task_name == 'unlabeled':
                self._model.run_teacher(sess, mb)
                loss = self._model.train_unlabeled(sess, mb)
                unsupervised_loss_total += loss
                unsupervised_loss_count += 1
                mb.teacher_predictions.clear()

            trained_on_sentences += mb.size
            global_step = self._model.get_global_step(sess)

            if global_step % self._config.print_every == 0:
                utils.log('step {:} - '
                          'supervised loss: {:.2f} - '
                          'unsupervised loss: {:.2f} - '
                          '{:.1f} sentences per second'.format(
                    global_step,
                    supervised_loss_total / max(1, supervised_loss_count),
                    unsupervised_loss_total / max(1, unsupervised_loss_count),
                    trained_on_sentences / (time.time() - start_time)))
                unsupervised_loss_total, unsupervised_loss_count = 0, 0
                supervised_loss_total, supervised_loss_count = 0, 0

            if global_step % self._config.eval_dev_every == 0:
                heading('EVAL ON DEV')
                self.evaluate_all_tasks(sess, summary_writer, progress.history)
                progress.save_if_best_dev_model(sess, global_step)
                utils.log()

            if global_step % self._config.eval_train_every == 0:
                heading('EVAL ON TRAIN')
                self.evaluate_all_tasks(sess, summary_writer, progress.history, True)
                utils.log()

            if global_step % self._config.save_model_every == 0:
                heading('CHECKPOINTING MODEL')
                progress.write(sess, global_step)
                utils.log()
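For intuition, under the mini config from Example #2 (eval_dev_every=75, eval_train_every=150, save_model_every=10), global step 150 would trigger a dev eval, a train eval, and a checkpoint all on the same step; print_every is not set there and is assumed to come from the Config defaults.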
Example #6
    def _evaluate_task(self, sess, task, summary_writer, train_set):
        scorer = task.get_scorer()
        data = task.train_set if train_set else task.val_set
        output_predictions = []
        mapping_dict = utils.load_pickle(
            self._config.preprocessed_data_topdir +
            "/chunk_BIOES_label_mapping.pkl")
        inv_map = {v: k for k, v in mapping_dict.items()}
        for i, mb in enumerate(
                data.get_minibatches(self._config.test_batch_size)):
            loss, batch_preds = self._model.test(sess, mb)
            normal_list = batch_preds.tolist()
            for j in range(len(normal_list)):
                tokens = str(mb.examples[j]).split()
                preds = [inv_map[q] for q in normal_list[j]]
                output_predictions.append(list(zip(tokens, preds)))
            scorer.update(mb.examples, batch_preds, loss)

        results = scorer.get_results(task.name +
                                     ('_train_' if train_set else '_dev_'))
        utils.log(task.name.upper() + ': ' + scorer.results_str())
        write_summary(summary_writer, results,
                      global_step=self._model.get_global_step(sess))
        return output_predictions, results
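For illustration, each entry of the returned output_predictions is a list of (token, tag) pairs for one sentence, with tag ids mapped back to BIOES strings through inv_map; a hypothetical entry: [('He', 'S-NP'), ('ran', 'S-VP')].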
Example #7
    def __init__(self, config):
        self._config = config
        self.tasks = [task_definitions.get_task(self._config, task_name)
                      for task_name in self._config.task_names]

        utils.log('Loading Pretrained Embeddings')
        pretrained_embeddings = utils.load_pickle(self._config.word_embeddings)

        utils.log('Building Model')
        self._model = multitask_model.Model(
            self._config, pretrained_embeddings, self.tasks)
        utils.log()
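This is the constructor invoked as trainer.Trainer(config) in the graph setup of Example #2: the task definitions and pretrained word embeddings are loaded first, then the shared multitask model is built over all tasks.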
Example #8
    def build(self):
        utils.log('loading pretrained embeddings from',
                  self.config.pretrained_embeddings_file)
        for special in SPECIAL_TOKENS:
            self._add_vector(special)
        for extra in _EXTRA_WORDS:
            self._add_vector(extra)
        with tf.gfile.GFile(self.config.pretrained_embeddings_file, 'r') as f:
            for i, line in enumerate(f):
                if i % 10000 == 0:
                    utils.log('on line', i)

                split = line.split()
                w = normalize_word(split[0])

                try:
                    vec = np.array(list(map(float, split[1:])),
                                   dtype='float32')
                    if vec.size != self.vector_size:
                        utils.log('vector for line', i, 'has size', vec.size,
                                  'so skipping')
                        utils.log(line[:100] + '...')
                        continue
                except ValueError:  # non-numeric vector components
                    utils.log("can't parse line", i, 'so skipping')
                    utils.log(line[:100] + '...')
                    continue
                if w not in self.vocabulary:
                    self.vocabulary[w] = len(self.vectors)
                    self.vectors.append(vec)
        utils.log('writing vectors!')
        self._write()
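A minimal sketch of the per-line parsing above, on a made-up three-dimensional row (real files like glove.6B.300d.txt carry a token followed by 300 floats):

import numpy as np

line = 'the 0.418 0.250 -0.412'  # hypothetical embedding row
split = line.split()
w = split[0]
vec = np.array(list(map(float, split[1:])), dtype='float32')
assert vec.size == 3  # build() compares this against self.vector_size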
Example #9
def main(data_dir='./data'):
    random.seed(0)
    utils.log("BUILDING WORD VOCABULARY/EMBEDDINGS")
    for pretrained in ['glove.6B.300d.txt']:
        config = configure.Config(data_dir=data_dir,
                                  for_preprocessing=True,
                                  pretrained_embeddings=pretrained,
                                  word_embedding_size=300)
        embeddings.PretrainedEmbeddingLoader(config).build()

    utils.log("CONSTRUCTING DEV SETS")
    for task_name in ["chunk"]:
        # chunking does not come with a provided dev split, so create one by
        # selecting a random subset of the data
        config = configure.Config(data_dir=data_dir,
                                  for_preprocessing=True)
        task_data_dir = os.path.join(config.raw_data_topdir, task_name) + '/'
        train_sentences = word_level_data.TaggedDataLoader(
            config, task_name, False).get_labeled_sentences("train")
        random.shuffle(train_sentences)
        if 'mini_data' not in data_dir:
            write_sentences(task_data_dir + 'train_subset.txt',
                            train_sentences[1500:])
            write_sentences(task_data_dir + 'dev.txt', train_sentences[:1500])
        else:
            write_sentences(task_data_dir + 'train_subset.txt',
                            train_sentences[len(train_sentences) // 4:])
            write_sentences(task_data_dir + 'dev.txt',
                            train_sentences[:len(train_sentences) // 4])

    utils.log("WRITING LABEL MAPPINGS")
    for task_name in ["chunk"]:
        for i, label_encoding in enumerate(["BIOES"]):
            config = configure.Config(data_dir=data_dir,
                                      for_preprocessing=True,
                                      label_encoding=label_encoding)
            token_level = task_name in ["ccg", "pos", "depparse"]
            loader = word_level_data.TaggedDataLoader(config, task_name, token_level)
            if token_level:
                if i != 0:
                    continue
                utils.log("WRITING LABEL MAPPING FOR", task_name.upper())
            else:
                utils.log("  Writing label mapping for", task_name.upper(),
                          label_encoding)
            utils.log(" ", len(loader.label_mapping), "classes")
            utils.write_pickle(loader.label_mapping,
                               loader.label_mapping_path)
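write_sentences is referenced above but not shown in this listing. Here is a minimal sketch, assuming each sentence is a (words, tags) pair as in Example #3 and a CoNLL-style layout with one "word tag" pair per line and a blank line between sentences; the exact on-disk format is an assumption.

def write_sentences(fname, sentences):
    # Hypothetical helper: the real implementation may differ.
    with tf.gfile.GFile(fname, 'w') as f:
        for words, tags in sentences:
            for word, tag in zip(words, tags):
                f.write(word + ' ' + tag + '\n')
            f.write('\n')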