def set_embeddings(self):
    # Read word embeddings.
    if not self.opt.get('embedding_file'):
        logger.warning(
            '[ WARNING: No embeddings provided. '
            'Keeping random initialization. ]'
        )
        return
    logger.info('[ Loading pre-trained embeddings ]')
    embeddings = load_embeddings(self.opt, self.word_dict)
    logger.info('[ Num embeddings = %d ]' % embeddings.size(0))

    # Sanity check dimensions
    new_size = embeddings.size()
    old_size = self.network.embedding.weight.size()
    if new_size[1] != old_size[1]:
        raise RuntimeError('Embedding dimensions do not match.')
    if new_size[0] != old_size[0]:
        logger.warning(
            '[ WARNING: Number of embeddings changed (%d->%d) ]'
            % (old_size[0], new_size[0])
        )

    # Swap weights
    self.network.embedding.weight.data = embeddings

    # If partially tuning the embeddings, keep the old values
    if self.opt['tune_partial'] > 0:
        if self.opt['tune_partial'] + 2 < embeddings.size(0):
            fixed_embedding = embeddings[self.opt['tune_partial'] + 2:]
            self.network.fixed_embedding = fixed_embedding
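# Hedged companion sketch (not in the original excerpt): one way the stored
# fixed_embedding is typically restored after each optimizer step, so that only
# the `tune_partial` most frequent words keep being fine-tuned. The method name
# and the exact call site in the training loop are assumptions for illustration.
def reset_fixed_embeddings(self):
    if self.opt['tune_partial'] > 0 and hasattr(self.network, 'fixed_embedding'):
        # The first two rows are reserved (e.g. padding/unknown), matching the
        # `tune_partial + 2` offset used in set_embeddings above.
        offset = self.opt['tune_partial'] + 2
        if offset < self.network.embedding.weight.data.size(0):
            self.network.embedding.weight.data[offset:] = self.network.fixed_embedding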
def text2spvec(self, query):
    """Create a sparse tfidf-weighted word vector from query.

    tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
    """
    # Get hashed ngrams
    words = self.parse(utils.normalize(query))
    wids = [utils.hash(w, self.hash_size) for w in words]

    if len(wids) == 0:
        if self.strict:
            raise RuntimeError('No valid word in: %s' % query)
        else:
            logger.warning('No valid word in: %s' % query)
            return sp.csr_matrix((1, self.hash_size))

    # Count TF
    wids_unique, wids_counts = np.unique(wids, return_counts=True)
    tfs = np.log1p(wids_counts)

    # Count IDF
    Ns = self.doc_freqs[wids_unique]
    idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
    idfs[idfs < 0] = 0

    # TF-IDF
    data = np.multiply(tfs, idfs)

    # One row, sparse csr matrix
    indptr = np.array([0, len(wids_unique)])
    spvec = sp.csr_matrix(
        (data, wids_unique, indptr), shape=(1, self.hash_size)
    )

    return spvec
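# Hedged worked example (standalone, not part of the class above): shows how the
# tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) weighting used in
# text2spvec behaves. The corpus size and document frequencies below are made up
# purely for illustration.
import numpy as np

num_docs = 5000                  # assumed corpus size N
doc_freqs = np.array([3, 250])   # assumed Nt for a rare and a common hashed term
term_counts = np.array([2, 2])   # assumed tf of each term within the query

tfs = np.log1p(term_counts)
idfs = np.log((num_docs - doc_freqs + 0.5) / (doc_freqs + 0.5))
idfs[idfs < 0] = 0
weights = np.multiply(tfs, idfs)
# The rare term (Nt=3) receives a much larger weight than the common one (Nt=250).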
def __init__(self, **kwargs):
    """
    Args:
        annotators: None or empty set (only tokenizes).
        substitutions: if true, normalizes some token types (e.g. quotes).
    """
    self._regexp = regex.compile(
        '(?P<digit>%s)|(?P<title>%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|'
        '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|'
        '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|'
        '(?P<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)'
        % (
            self.DIGIT,
            self.TITLE,
            self.ABBRV,
            self.NEGATION,
            self.HYPHEN,
            self.CONTRACTION1,
            self.ALPHA_NUM,
            self.CONTRACTION2,
            self.START_DQUOTE,
            self.END_DQUOTE,
            self.START_SQUOTE,
            self.END_SQUOTE,
            self.DASH,
            self.ELLIPSES,
            self.PUNCT,
            self.NON_WS,
        ),
        flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
    )
    if len(kwargs.get('annotators', {})) > 0:
        logger.warning(
            '%s only tokenizes! Skipping annotators: %s'
            % (type(self).__name__, kwargs.get('annotators'))
        )
    self.annotators = set()
    self.substitutions = kwargs.get('substitutions', True)
def save(self, filename):
    params = {
        'state_dict': {'network': self.network.state_dict()},
        'feature_dict': self.feature_dict,
        'config': self.opt,
    }
    try:
        torch.save(params, filename)
    except BaseException:
        logger.warning('[ WARN: Saving failed... continuing anyway. ]')
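# Hedged counterpart sketch for reading back what save() writes above. The keys
# mirror the params dict; how the network is rebuilt from 'config' depends on the
# surrounding class and is left out here.
import torch

def load_checkpoint(filename):
    saved = torch.load(filename, map_location='cpu')
    opt = saved['config']
    feature_dict = saved['feature_dict']
    network_state = saved['state_dict']['network']
    return opt, feature_dict, network_state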
def __init__(self, **kwargs):
    """
    Args:
        annotators: None or empty set (only tokenizes).
    """
    self._regexp = regex.compile(
        '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
        flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
    )
    if len(kwargs.get('annotators', {})) > 0:
        logger.warning(
            '%s only tokenizes! Skipping annotators: %s'
            % (type(self).__name__, kwargs.get('annotators'))
        )
    self.annotators = set()
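# Hedged usage sketch of the simple "alphanumeric run or single non-whitespace
# character" pattern compiled above. ALPHA_NUM and NON_WS are assumed to be the
# usual DrQA-style class constants; the literals below are illustrative copies,
# not the actual attributes.
import regex

ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
NON_WS = r'[^\p{Z}\p{C}]'
pattern = regex.compile(
    '(%s)|(%s)' % (ALPHA_NUM, NON_WS),
    flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
)
print([m.group() for m in pattern.finditer("Can't stop!")])
# -> ['Can', "'", 't', 'stop', '!']  (letter/digit runs kept whole, punctuation split off)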
def __init__(self, opt: Opt, shared: TShared = None):
    opt = copy.deepcopy(opt)
    self.delimiter = opt.get('delimiter', '\n')
    split_type = SplitType(opt.get("cmu_dog_split_type"))
    if split_type == SplitType.ORIGINAL:
        logger.warning(
            "`original` split type contains duplicate conversations across train, "
            "valid, and test. See https://github.com/festvox/datasets-CMU_DoG/issues/2 "
            "for more detail."
        )
    opt['datafile'] = _datafile(
        split=DatatypeHelper.fold(opt['datatype']), split_type=split_type
    )
    super().__init__(opt, shared)
    if shared:
        self.rare_word_f1 = shared['rare_word_f1']
    else:
        self.rare_word_f1 = _build_rare_word_f1(opt)
def build(opt):
    version = MSC_DATASETS_VERSION
    # Resolve the dataset directory for this particular configuration.
    dpath = get_msc_dir_path(opt)
    if not build_data.built(dpath, version):
        logger.warning('[build data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)
        for downloadable_file in RESOURCES:
            downloadable_file.download_file(dpath)
        # Mark the data as built.
        build_data.mark_done(dpath, version)
    return dpath
def _datafile(split: str, split_type: SplitType) -> str:
    """
    Returns the filename, e.g. train.json.
    """
    if split_type == SplitType.ORIGINAL:
        return f"{split}.json"
    if split_type == SplitType.SEEN:
        if 'test' in split:
            return "test_seen_split_seen_unseen.json"
        return f"{split}_split_seen_unseen.json"
    if split_type == SplitType.UNSEEN:
        if "test" not in split:
            logger.warning(
                "Trying to use a non-test dataset with split `unseen`. `unseen` "
                "only returns the unseen test set. Are you sure you didn't mean to "
                "use `seen` here?"
            )
        return "test_unseen_split_seen_unseen.json"
    return f"{split}_deduped.json"
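# Hedged illustration of the filename mapping implemented by _datafile above,
# assuming SplitType exposes the ORIGINAL, SEEN, and UNSEEN members used in the
# function body.
assert _datafile('train', SplitType.ORIGINAL) == 'train.json'
assert _datafile('valid', SplitType.SEEN) == 'valid_split_seen_unseen.json'
assert _datafile('test', SplitType.SEEN) == 'test_seen_split_seen_unseen.json'
assert _datafile('test', SplitType.UNSEEN) == 'test_unseen_split_seen_unseen.json'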
def __init__(self, opt, shared=None):
    self.summary_num_turns = opt['summary_num_turns']
    assert (
        self.summary_num_turns < 0 or self.summary_num_turns % 2 == 0
    ), "Please choose an even number for turns"
    self.session_id = opt['session_id']
    assert opt['session_id'] <= 4, f"No data beyond session {opt['session_id']}!"
    assert (
        opt['session_id'] <= 3 or 'train' not in opt['datatype']
    ), f"No train data beyond session {opt['session_id']}!"
    self.nopersona_subsampling_weight = opt['nopersona_subsampling_weight']
    if 'test' in opt['datatype']:
        logger.warning(f'WARNING: No subsampling for {opt["datatype"]}')
        self.nopersona_subsampling_weight = 1
    assert (
        self.nopersona_subsampling_weight >= 0
        and self.nopersona_subsampling_weight <= 1
    ), "invalid subsampling weight"
    dpath = build(opt)
    opt['datafile'] = get_sessionbase_dir_path(opt, dpath, 'msc_personasummary')
    self.id = f'msc_personasummary_{self.session_id}'
    super().__init__(opt, shared)
def setup_data(self, datafile):
    print('loading: ' + datafile)
    if self.datatype.startswith('train'):
        path_to_open = os.path.join(datafile, 'train.txt')
    elif self.datatype.startswith('valid'):
        path_to_open = os.path.join(datafile, 'valid.txt')
    else:
        path_to_open = os.path.join(datafile, 'test.txt')

    with PathManager.open(path_to_open) as f:
        raw_data = [json.loads(line.strip()) for line in f]

    data = []
    label_speaker_id_range = {}
    predicted_summary_dict = {}
    if self.use_predicted_summary:
        is_session_level = 'utt_' not in self.previous_persona_type
        predsum_path = get_predicted_summary_path(self.msc_dpath, is_session_level)
        logger.warning(f"Using predicted summaries from {predsum_path}")
        with PathManager.open(predsum_path) as jsonfile:
            predicted_summary_dict = json.load(jsonfile)

    def _get_time_gap(time_num, time_unit, time_token=""):
        time_gap = str(time_num) + ' ' + time_unit
        return f'{time_token} {time_gap}' if len(time_token) > 0 else time_gap

    def _compile_persona_dialog_input(
        dialog, personas, previous_dialogs, label_speaker_id
    ):
        new_dialog = copy.deepcopy(dialog)
        new_previous_dialogs = copy.deepcopy(previous_dialogs)
        your_persona = ""
        partner_persona = ""
        if label_speaker_id == 'self':
            your_persona = '\n'.join([f'your persona: {x}' for x in personas[1]])
            partner_persona = '\n'.join(
                [f"partner's persona: {x}" for x in personas[0]]
            )
        elif label_speaker_id == 'their':
            your_persona = '\n'.join([f'your persona: {x}' for x in personas[0]])
            partner_persona = '\n'.join(
                [f"partner's persona: {x}" for x in personas[1]]
            )
        for prev_dialog in new_previous_dialogs:
            prev_dialog['dialog'].insert(0, {"text": DUMMY_TEXT})
            if len(prev_dialog['dialog']) % 2 == 1 and (
                self.history_person_tokens is None
            ):
                prev_dialog['dialog'].append({"text": DUMMY_TEXT})
        new_dialog.insert(0, {"text": DUMMY_TEXT})
        return your_persona, partner_persona, new_dialog, new_previous_dialogs

    for dialog_dict in raw_data:
        initial_data_id = dialog_dict['metadata']['initial_data_id']
        if self.label_speaker_id == 'both':
            label_speaker_id_range = ['their', 'self']
        else:
            label_speaker_id_range = [self.label_speaker_id]

        for label_speaker_id in label_speaker_id_range:
            if self.use_predicted_summary:
                personas_to_compile = predicted_summary_dict[
                    str(self.session_id - 1)
                ][initial_data_id]
            elif self.previous_persona_type.startswith('init'):
                personas_to_compile = dialog_dict['init_personas']
            else:
                personas_to_compile = dialog_dict['personas']

            (
                your_persona,
                partner_persona,
                new_dialog,
                new_previous_dialogs,
            ) = _compile_persona_dialog_input(
                dialog_dict['dialog'],
                personas_to_compile,
                dialog_dict['previous_dialogs'],
                label_speaker_id,
            )

            previous_sessions_msgs = []
            if self.previous_persona_type == 'raw_history':
                for d_id in range(len(new_previous_dialogs)):
                    previous_dialog_msg = [
                        x['text'] for x in new_previous_dialogs[d_id]['dialog']
                    ]
                    if self.history_person_tokens:
                        previous_dialog_msg = [
                            self.history_person_tokens[i % 2] + ' ' + text
                            for i, text in enumerate(previous_dialog_msg)
                            if text != DUMMY_TEXT
                        ]
                    if self.history_time_gaps_token:
                        time_gap_i = _get_time_gap(
                            new_previous_dialogs[d_id]['time_num'],
                            new_previous_dialogs[d_id]['time_unit'],
                            time_token=self.history_time_gaps_token,
                        )
                        previous_sessions_msgs.append(
                            '\n'.join(previous_dialog_msg + [time_gap_i])
                        )
                    else:
                        previous_sessions_msgs.append(
                            '\n'.join(previous_dialog_msg)
                        )

            if self.previous_session_delimiter is not None:
                previous_sessions_msgs = [
                    val
                    for pair in zip(
                        previous_sessions_msgs,
                        [self.previous_session_delimiter]
                        * len(previous_sessions_msgs),
                    )
                    for val in pair
                ]
            previous_sessions_msgs = '\n'.join(previous_sessions_msgs)

            episode = []
            for i in range(0, len(new_dialog) - 1, 2):
                text = new_dialog[i]['text']
                partner_persona_one_line = partner_persona.replace('\n', '').split(
                    "partner's persona: "
                )
                your_persona_one_line = your_persona.replace('\n', '').split(
                    "your persona: "
                )
                action = {
                    'id': self.id,
                    'text': self.normalize_replies(text),
                    'labels': [self.normalize_replies(new_dialog[i + 1]['text'])],
                    'session_id': self.session_id,
                    'initial_data_id': initial_data_id,
                    'personas': f'{partner_persona}\n{your_persona}',
                    'personas_one_line': f"partner's persona: {' '.join(partner_persona_one_line)}\nyour persona: {' '.join(your_persona_one_line)}",
                }
                if i == 0:
                    action.update(
                        {
                            'time_num': dialog_dict['previous_dialogs'][-1]['time_num'],
                            'time_unit': dialog_dict['previous_dialogs'][-1]['time_unit'],
                        }
                    )
                episode.append(action)
                if self.session_openning:
                    break

            persona_context_str = ""
            if 'self' in self.previous_persona_type:
                persona_context_str = your_persona
            elif 'their' in self.previous_persona_type:
                persona_context_str = partner_persona
            elif 'both' in self.previous_persona_type:
                if self.your_persona_first:
                    persona_context_str = (
                        (your_persona + '\n') if len(your_persona) > 0 else ""
                    ) + partner_persona
                else:
                    persona_context_str = (
                        (partner_persona + '\n') if len(partner_persona) > 0 else ""
                    ) + your_persona
            elif self.previous_persona_type == 'raw_history':
                persona_context_str = previous_sessions_msgs

            if self.include_last_time_gap:
                time_gap = _get_time_gap(
                    dialog_dict['previous_dialogs'][-1]['time_num'],
                    dialog_dict['previous_dialogs'][-1]['time_unit'],
                )
                persona_context_str = (
                    (persona_context_str + '\n')
                    if len(persona_context_str) > 0
                    else ""
                ) + f'[{time_gap}]'
            if persona_context_str and len(persona_context_str) > 0:
                episode[0]['text'] = persona_context_str + '\n' + episode[0]['text']

            data.append(episode)

    for episode in data:
        start_idx = 0
        for i, turn in enumerate(episode):
            yield Message(turn), i == start_idx