예제 #1
0
def _punctuation_ids(vocab_path):
    """Collect the vocabulary lookups for every punctuation token.

    The object loaded from ``vocab_path`` with ``utils.load_cpickle`` is
    wrapped in a ``Vocabulary``; each entry yielded by iterating it is kept
    when it equals one of the punctuation strings below.

    Args:
        vocab_path: location handed to ``utils.load_cpickle``.

    Returns:
        A set holding ``vocab.get(token)`` for each matching token
        (presumably the token ids -- confirm against ``Vocabulary.get``).
    """
    punctuation = (
        '!', '...', '``', '{', '}', '(', ')', '[', ']', '--', '-', ',',
        '.', "''", '`', ';', ':', '?',
    )
    vocab = Vocabulary(utils.load_cpickle(vocab_path))
    return {vocab.get(token) for token in vocab if token in punctuation}
예제 #2
0
    def __init__(self, config):
        """Set up the tasks, load pretrained embeddings and build the model.

        Args:
            config: configuration object; this method reads its
                ``task_names`` and ``word_embeddings`` attributes and passes
                the object on to ``task_definitions.get_task`` and
                ``multitask_model.Model``.
        """
        self._config = config

        # One task object per configured task name, in configuration order.
        tasks = []
        for name in config.task_names:
            tasks.append(task_definitions.get_task(config, name))
        self.tasks = tasks

        utils.log('Loading Pretrained Embeddings')
        embeddings_path = config.word_embeddings
        pretrained_embeddings = utils.load_cpickle(embeddings_path)

        utils.log('Building Model')
        model = multitask_model.Model(config, pretrained_embeddings,
                                      self.tasks)
        self._model = model
        utils.log()
예제 #3
0
    def label_mapping(self):
        """Return the mapping from label string to integer id for this task.

        When ``self._config.for_preprocessing`` is falsy the mapping is read
        from ``self.label_mapping_path`` via ``utils.load_cpickle``.
        Otherwise it is rebuilt from the labeled sentences of the 'train',
        'dev' and 'test' splits:

        * Tasks that are not token-level have their tags converted to span
          labels and back to per-token tags with the configured label
          encoding before counting.
        * For 'depparse', only the second '-'-separated field of each tag is
          used as the label, and 'root' is forced to id 0.
        * For 'ccg', labels seen in the train split get ids 0..n-1 in sorted
          order, and every label that never occurs in train shares id n.
        * All other tasks number their labels in sorted order.

        Returns:
            dict mapping each label to its integer id.

        Raises:
            ValueError: for 'depparse' when no 'root' label was seen.
            IndexError: for 'depparse' when a tag contains no '-'.
        """
        if not self._config.for_preprocessing:
            return utils.load_cpickle(self.label_mapping_path)

        tag_counts = collections.Counter()
        train_tags = set()
        # The task name does not change while scanning, so test it once.
        is_depparse = self._task_name == 'depparse'
        for split in ['train', 'dev', 'test']:
            for words, tags in self.get_labeled_sentences(split):
                if not self._is_token_level:
                    span_labels = tagging_utils.get_span_labels(tags)
                    tags = tagging_utils.get_tags(span_labels, len(words),
                                                  self._config.label_encoding)
                for tag in tags:
                    if is_depparse:
                        tag = tag.split('-')[1]
                    tag_counts[tag] += 1
                    if split == 'train':
                        train_tags.add(tag)

        if self._task_name == 'ccg':
            # For CCG, the dev/test sets contain tags that never occur in the
            # train set. All such tags are mapped to one special label; the
            # model will never predict it because it never sees it during
            # training.
            # Every train tag is also a key of tag_counts, so filtering the
            # keys by set membership selects exactly the trainable labels.
            seen_in_train = sorted(
                tag for tag in tag_counts if tag in train_tags)
            label_mapping = {
                label: i for i, label in enumerate(seen_in_train)}
            unseen_id = len(label_mapping)
            for tag in tag_counts:
                if tag not in train_tags:
                    label_mapping[tag] = unseen_id
        else:
            labels = sorted(tag_counts)
            if is_depparse:
                # Move 'root' to the front so it always receives id 0.
                labels.remove('root')
                labels.insert(0, 'root')
            label_mapping = {label: i for i, label in enumerate(labels)}
        return label_mapping
예제 #4
0
    def __init__(self, config, sess, checkpoint_saver, best_model_saver,
                 restore_if_possible=True):
        """Initialise training state, resuming from saved progress if any.

        The checkpoint directory is created first. If ``restore_if_possible``
        is truthy and ``config.progress`` exists, the saved
        ``(history, current_file, current_line)`` tuple is loaded, the
        unlabeled-data reader is positioned at that file/line, and the latest
        checkpoint in ``config.checkpoints_dir`` is restored into ``sess``.
        Otherwise the history starts empty and the reader is built with its
        defaults.

        Args:
            config: configuration object; ``checkpoints_dir`` and
                ``progress`` are read here.
            sess: session passed to ``checkpoint_saver.restore``.
            checkpoint_saver: saver used to restore the latest checkpoint.
            best_model_saver: stored on the instance; not used here.
            restore_if_possible: when falsy, always start from scratch.
        """
        self.config = config
        self.checkpoint_saver = checkpoint_saver
        self.best_model_saver = best_model_saver

        tf.gfile.MakeDirs(config.checkpoints_dir)

        # Short-circuits: the progress file is only probed when resuming is
        # allowed at all.
        can_resume = restore_if_possible and tf.gfile.Exists(config.progress)
        if not can_resume:
            utils.log("No previous checkpoint found - starting from scratch")
            self.history = []
            self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(
                config)
            return

        saved_progress = utils.load_cpickle(config.progress, memoized=False)
        history, current_file, current_line = saved_progress
        self.history = history
        self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(
            config, current_file, current_line)

        last_step = dict(self.history[-1])["step"]
        position = "(lm1b file {:}, line {:})".format(
            current_file, current_line)
        utils.log("Continuing from global step", last_step, position)

        self.checkpoint_saver.restore(
            sess, tf.train.latest_checkpoint(self.config.checkpoints_dir))
예제 #5
0
def get_word_embeddings(config):
    """Load the object stored at ``config.word_embeddings``.

    Args:
        config: configuration object exposing ``word_embeddings``.

    Returns:
        Whatever ``utils.load_cpickle`` yields for that location.
    """
    embeddings_path = config.word_embeddings
    return utils.load_cpickle(embeddings_path)
예제 #6
0
def get_word_vocab(config):
    """Build a ``Vocabulary`` from the data at ``config.word_vocabulary``.

    Args:
        config: configuration object exposing ``word_vocabulary``.

    Returns:
        A ``Vocabulary`` wrapping the object loaded by
        ``utils.load_cpickle``.
    """
    raw_vocab = utils.load_cpickle(config.word_vocabulary)
    return Vocabulary(raw_vocab)