def read_all(path: Path, add_eos):
    """
    Lazily reads dialogs from a tab separated file, one utterance per line.

    Blank lines separate dialogs. On each non-blank line the last two columns
    are mandatory: an integer character id followed by white-space separated
    integer token ids. When a line has more than two columns the first one is
    taken as the utterance uid; when it has more than three columns the second
    one is parsed as a float weight.

    :param path: file to read (opened via ``IO.reader``)
    :param add_eos: when truthy, ``EOS_TOK_IDX`` is appended to every sequence
        that does not already end with it
    :return: generator of non-empty ``Dialog`` objects; the number of dialogs
        is logged once the generator has been fully consumed
    """
    count = 0
    with IO.reader(path) as reader:
        dialog = Dialog()
        for raw in reader:
            stripped = raw.strip()
            if not stripped:
                # a blank line closes the dialog collected so far (if any)
                if len(dialog) > 0:
                    yield dialog
                    count += 1
                    dialog = Dialog()
                continue

            cols = stripped.split("\t")
            char, seq = cols[-2:]  # the last two are mandatory
            uid = cols[0] if len(cols) > 2 else None
            weight = float(cols[1]) if len(cols) > 3 else None
            char = int(char)
            seq = [int(tok) for tok in seq.strip().split()]
            if add_eos and seq[-1] != EOS_TOK_IDX:
                seq.append(EOS_TOK_IDX)
            dialog.append(Utterance(char, seq, uid=uid, weight=weight))

        # the file may end without a trailing blank line
        if len(dialog) > 0:
            count += 1
            yield dialog
    log.info(f"Read {count} dialogs")
def write_tsv(records: Iterator[DialogRecord], path: Union[str, Path]):
    """
    Writes two-column records to a tab separated file.

    Each record is unpacked as a pair; the first item is written via ``str``
    and the second is written as its items joined by single spaces.

    :param records: iterable of ``(first, sequence)`` pairs
    :param path: destination (opened via ``IO.writer``)
    """
    # obtain the iterator up front so a non-iterable argument fails before
    # anything is logged or the output file is opened
    rows = iter(records)
    log.info(f"Storing data at {path}")
    with IO.writer(path) as f:
        for first, second in rows:
            x = str(first)
            y = ' '.join(map(str, second))
            f.write(f'{x}\t{y}\n')
def store_model(self, step: int, model, train_score: float, val_score: float, keep: int):
    """
    Saves a model checkpoint, prunes surplus checkpoints and logs the scores.

    The checkpoint file name encodes the step and both scores. After saving,
    every entry of ``self.list_models(sort='total_score', desc=False)`` beyond
    the first ``keep`` is removed from disk, and one row
    (step, timestamp, file name, train score, validation score) is appended
    to ``scores.tsv`` inside ``self.model_dir``.

    :param step: step number of training
    :param model: model object itself (passed as-is to ``torch.save``)
    :param train_score: score of model on training split
    :param val_score: score of model on validation split
    :param keep: number of models to keep; the rest of the sorted list is deleted
        (NOTE(review): which models rank first depends on ``list_models`` -- confirm)
    :return: None
    """
    # TODO: improve this by skipping the model save if the model is not good enough to be saved
    if self.read_only:
        log.warning("Ignoring the store request; experiment is readonly")
        return

    name = f'model_{step:05d}_{train_score:.6f}_{val_score:.6f}.pkl'
    path = self.model_dir / name
    log.info(f"Saving... step={step} to {path}")
    torch.save(model, str(path))

    # everything past the first `keep` entries of the sorted listing is removed
    surplus = self.list_models(sort='total_score', desc=False)[keep:]
    for bad_model in surplus:
        log.info(f"Deleting bad model {bad_model} . Keep={keep}")
        os.remove(str(bad_model))

    scores_path = os.path.join(self.model_dir, 'scores.tsv')
    with IO.writer(scores_path, append=True) as scores_out:
        row = '\t'.join([
            str(step),
            datetime.now().isoformat(),
            name,
            f'{train_score:g}',
            f'{val_score:g}',
        ])
        scores_out.write(row + '\n')
def write_lines(path: Union[str, Path], lines):
    """
    Writes each item of ``lines`` to ``path``, one per output line.

    Every item is stripped of surrounding white space before it is written,
    and a single newline is written after it.

    :param path: destination (opened via ``IO.writer``)
    :param lines: iterable of strings
    """
    count = 0
    with IO.writer(path) as out:
        for count, text in enumerate(lines, start=1):
            out.write(text.strip())
            out.write("\n")
    log.info(f"Wrote {count} lines to {path}")
def write_dialogs(dialogs: Iterator[Dialog], out: Path, dialog_sep='\n'):
    """
    Writes dialogs to a tab separated file, one utterance per line.

    For each utterance in ``dialog.chat`` the columns are: the uid (only when
    truthy), the weight formatted with ``:g`` (only when truthy), the
    character and finally the text items joined by single spaces.
    ``dialog_sep`` is written after the last utterance of every dialog.

    :param dialogs: iterable of dialogs
    :param out: destination (opened via ``IO.writer``)
    :param dialog_sep: text written after each dialog
    """
    count = 0
    with IO.writer(out) as outh:
        for count, dialog in enumerate(dialogs, start=1):
            for utter in dialog.chat:
                fields = []
                # NOTE(review): these are truthiness checks, so a weight of 0
                # (or an empty uid) is left out of the line -- confirm intended
                if utter.uid:
                    fields.append(f'{utter.uid}')
                if utter.weight:
                    fields.append(f'{utter.weight:g}')
                text = " ".join(map(str, utter.text))
                fields.append(f'{utter.char}')
                fields.append(text)
                outh.write('\t'.join(fields) + '\n')
            outh.write(dialog_sep)
    log.info(f"Wrote {count} recs to {out}")
def read_msg_resp(path: str):
    """
    Reads message/response pairs from a two-column tab separated source.

    Lines are stripped; only lines that split into exactly two tab separated
    columns are kept, all others (blank lines, wrong column count) are skipped.

    :param path: either a file system path (``str`` or any ``os.PathLike``),
        which is opened via ``IO.reader``, or an already opened reader / any
        iterable of lines, which is consumed directly
    :return: tuple ``(msgs, resps)`` of two equally long lists of strings
    """
    import os  # local import: only needed for the path-type check below

    def _read(rdr):
        # single pass over the lines, collecting both columns side by side
        msgs, resps = [], []
        for raw in rdr:
            cols = raw.strip().split('\t')
            if len(cols) == 2:
                msgs.append(cols[0])
                resps.append(cols[1])
        return msgs, resps

    # isinstance (rather than an exact type check) so that Path objects and
    # str subclasses are opened as files instead of being iterated directly
    if isinstance(path, (str, os.PathLike)):
        with IO.reader(path) as r:
            return _read(r)
    return _read(path)
def __init__(self, inp: Union[str, Path, TextIO, Iterator[str]], text_field: Field = None, char_field: LookupField = None, max_seq_len: int = 100, add_eos=True):
    """
    Sets up the reader state from a path or an already available line source.

    :param inp: dialog seq file. A ``str`` is converted to ``Path``; a ``Path``
        must exist and is opened via ``IO.reader(...).open()``. Anything else
        is stored unchanged as the reader.
    :param text_field: provide this field if you want to map text to word ids.
        by default it splits words by white space and return words as sequence
    :param char_field: provide this field if you want to map character name to id.
    :param max_seq_len: stored on the instance as ``max_seq_len``
    :param add_eos: stored on the instance as ``add_eos``
    """
    self.text_field = text_field
    self.char_field = char_field
    self.max_seq_len = max_seq_len
    self.add_eos = add_eos
    self.num_cols = 0

    source = Path(inp) if type(inp) is str else inp
    if isinstance(source, Path):
        assert source.exists()
        source = IO.reader(source).open()
    self.reader = source
def read_raw_lines(dialog_path: Union[str, Path]) -> Iterator[RawRecord]:
    """
    Lazily reads ``(character, dialog)`` pairs from a tab separated file.

    Only the last two columns of each line are used; both are stripped of
    surrounding white space. Lines with fewer than two columns (for example
    blank lines) and lines where either value is empty after stripping are
    skipped.

    :param dialog_path: file to read (opened via ``IO.reader``)
    :return: generator of ``(char, dialog)`` string tuples
    """
    with IO.reader(dialog_path) as lines:
        for line in lines:
            cols = line.split("\t")
            if len(cols) < 2:
                # no tab on this line: unpacking a single column into
                # (char, dialog) would raise ValueError, so skip it
                continue
            char, dialog = cols[-2].strip(), cols[-1].strip()
            if char and dialog:
                yield char, dialog
def store_config(self):
    """
    Serialises ``self.config`` as YAML into ``self._config_file``.

    The file is opened via ``IO.writer`` and the dump is made with
    ``default_flow_style=False``.

    :return: the value returned by ``yaml.dump``
    """
    with IO.writer(self._config_file) as config_out:
        return yaml.dump(self.config, config_out, default_flow_style=False)
def _read_char_names():
    """
    Yields the second-to-last tab separated column of every line in ``path``.

    ``path`` is not a parameter; it is resolved from the surrounding scope.
    Lines are stripped first, and lines with fewer than two columns are skipped.
    """
    with IO.reader(path) as inp:
        for raw in inp:
            columns = raw.strip().split('\t')
            if len(columns) < 2:
                continue
            yield columns[-2]
def read_lines(path: Union[str, Path]):
    """
    Reads the whole file into memory and returns its lines.

    :param path: file to read (opened via ``IO.reader``)
    :return: list of lines, each stripped of surrounding white space
    """
    with IO.reader(path) as f:
        raw_lines = f.readlines()
    stripped = []
    for raw in raw_lines:
        stripped.append(raw.strip())
    return stripped
def read_tsv(path: str):
    """
    Lazily yields the tab separated columns of every line in ``path``.

    This is a generator, so the existence check and the file open only happen
    once iteration starts. Lines are split as they are read, without
    stripping, so the last column keeps any trailing newline.

    :param path: file to read (opened via ``IO.reader``)
    :return: generator of lists of column strings
    """
    assert os.path.exists(path)
    with IO.reader(path) as f:
        for row in f:
            yield row.split('\t')
def read_lines(path):
    """
    Lazily yields the items produced by ``read_lines_reader`` for a source.

    :param path: either a file system path (``str`` or any ``os.PathLike``),
        which is opened via ``IO.reader`` and closed once iteration finishes,
        or an already opened reader, which is handed to ``read_lines_reader``
        directly
    :return: generator over whatever ``read_lines_reader`` yields
    """
    import os  # local import: only needed for the path-type check below

    if isinstance(path, (str, os.PathLike)):
        with IO.reader(path) as reader:
            yield from read_lines_reader(reader)
    else:
        # This function is a generator, so a plain `return <iterable>` here
        # would end the iteration without yielding anything; delegate instead.
        yield from read_lines_reader(path)