def _punctuation_ids(vocab_path):
  """Returns the vocabulary ids of punctuation tokens."""
  vocab = Vocabulary(utils.load_cpickle(vocab_path))
  return set(vocab.get(w) for w in vocab if w in [
      '!', '...', '``', '{', '}', '(', ')', '[', ']', '--', '-', ',', '.',
      "''", '`', ';', ':', '?'])
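
# Hedged usage sketch (not from the repo): one common consumer of an id set
# like the one above is evaluation code that skips punctuation tokens, e.g.
# when computing dependency-parsing attachment scores. `toy_vocab` and the
# sentence ids are made-up stand-ins; Vocabulary is assumed to behave like a
# dict mapping token -> integer id.
toy_vocab = {'The': 0, 'dog': 1, 'ran': 2, '.': 3}
punct_ids = {i for w, i in toy_vocab.items()
             if w in ['!', '...', ',', '.', ';', ':', '?']}
sentence_ids = [0, 1, 2, 3]
scored_positions = [t for t in sentence_ids if t not in punct_ids]
assert scored_positions == [0, 1, 2]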
def __init__(self, config):
  self._config = config
  self.tasks = [task_definitions.get_task(self._config, task_name)
                for task_name in self._config.task_names]

  utils.log('Loading Pretrained Embeddings')
  pretrained_embeddings = utils.load_cpickle(self._config.word_embeddings)

  utils.log('Building Model')
  self._model = multitask_model.Model(
      self._config, pretrained_embeddings, self.tasks)
  utils.log()
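
# Minimal sketch of the config-driven task construction above; _ToyConfig and
# _get_task are hypothetical stand-ins for the repo's config object and
# task_definitions.get_task, shown only to illustrate the registry pattern.
class _ToyConfig(object):
  task_names = ['chunk', 'depparse']

def _get_task(config, name):
  return {'name': name}  # a real task would bundle data, a model head, etc.

tasks = [_get_task(_ToyConfig, n) for n in _ToyConfig.task_names]
assert [t['name'] for t in tasks] == ['chunk', 'depparse']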
def label_mapping(self):
  if not self._config.for_preprocessing:
    return utils.load_cpickle(self.label_mapping_path)

  tag_counts = collections.Counter()
  train_tags = set()
  for split in ['train', 'dev', 'test']:
    for words, tags in self.get_labeled_sentences(split):
      if not self._is_token_level:
        span_labels = tagging_utils.get_span_labels(tags)
        tags = tagging_utils.get_tags(
            span_labels, len(words), self._config.label_encoding)
      for tag in tags:
        if self._task_name == 'depparse':
          tag = tag.split('-')[1]  # keep only the part after the dash
        tag_counts[tag] += 1
        if split == 'train':
          train_tags.add(tag)

  if self._task_name == 'ccg':
    # For CCG, there are tags in the test sets that aren't in the train set.
    # All tags not in the train set get mapped to a special label; the model
    # will never predict this label because it never sees it in the training
    # set.
    not_in_train_tags = [tag for tag in tag_counts
                         if tag not in train_tags]
    label_mapping = {
        label: i for i, label in enumerate(sorted(
            t for t in tag_counts if t not in not_in_train_tags))
    }
    n = len(label_mapping)
    for tag in not_in_train_tags:
      label_mapping[tag] = n
  else:
    labels = sorted(tag_counts.keys())
    if self._task_name == 'depparse':
      # Ensure the 'root' label always gets id 0.
      labels.remove('root')
      labels.insert(0, 'root')
    label_mapping = {label: i for i, label in enumerate(labels)}
  return label_mapping
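
# Self-contained illustration (toy data, not from the repo) of the CCG branch
# above: every tag absent from the training split collapses onto a single
# extra id, so the model can never predict it.
import collections
tag_counts = collections.Counter({'N': 10, 'NP': 5, '(S\\NP)/NP': 1})
train_tags = {'N', 'NP'}
unseen = [t for t in tag_counts if t not in train_tags]
mapping = {t: i for i, t in enumerate(
    sorted(t for t in tag_counts if t not in unseen))}
extra_id = len(mapping)
for t in unseen:
  mapping[t] = extra_id
assert mapping == {'N': 0, 'NP': 1, '(S\\NP)/NP': 2}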
def __init__(self, config, sess, checkpoint_saver, best_model_saver,
             restore_if_possible=True):
  self.config = config
  self.checkpoint_saver = checkpoint_saver
  self.best_model_saver = best_model_saver

  tf.gfile.MakeDirs(config.checkpoints_dir)
  if restore_if_possible and tf.gfile.Exists(config.progress):
    history, current_file, current_line = utils.load_cpickle(
        config.progress, memoized=False)
    self.history = history
    self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(
        config, current_file, current_line)
    utils.log("Continuing from global step", dict(self.history[-1])["step"],
              "(lm1b file {:}, line {:})".format(current_file, current_line))
    self.checkpoint_saver.restore(sess, tf.train.latest_checkpoint(
        self.config.checkpoints_dir))
  else:
    utils.log("No previous checkpoint found - starting from scratch")
    self.history = []
    self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(config)
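
# Hedged sketch of the progress-pickle round trip the restore path assumes:
# a (history, current_file, current_line) tuple in which each history entry
# converts to a dict with a "step" key. The file name and values are made up.
import os
import pickle
import tempfile

history = [[('step', 100), ('train_loss', 2.3)]]
progress = (history, 'lm1b-shard-00001', 4096)
path = os.path.join(tempfile.mkdtemp(), 'progress.pkl')
with open(path, 'wb') as f:
  pickle.dump(progress, f)
with open(path, 'rb') as f:
  history, current_file, current_line = pickle.load(f)
assert dict(history[-1])['step'] == 100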
def get_word_embeddings(config):
  return utils.load_cpickle(config.word_embeddings)

def get_word_vocab(config):
  return Vocabulary(utils.load_cpickle(config.word_vocabulary))
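
# The memoized=False flag passed to utils.load_cpickle in the trainer above
# suggests the loader caches by default, which matters for wrappers like the
# two above that may be called from several places. A sketch of that caching
# behavior, under that assumption (not the repo's actual implementation):
import pickle

_PICKLE_CACHE = {}

def load_cpickle_sketch(path, memoized=True):
  if memoized and path in _PICKLE_CACHE:
    return _PICKLE_CACHE[path]
  with open(path, 'rb') as f:
    obj = pickle.load(f)
  if memoized:
    _PICKLE_CACHE[path] = obj
  return obj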