Exemplo n.º 1
0
 def read_all(path: Path, add_eos):
     """
     Lazily parse a dialog file and yield one Dialog at a time.

     Every non-blank line is one utterance made of tab-separated columns,
     laid out as [uid <TAB>] [weight <TAB>] char <TAB> seq:

     * the final two columns (char, seq) are mandatory;
     * with three or more columns, the first column is taken as the uid;
     * with four or more columns, the second column is parsed as a float weight;
     * char is parsed with int() and seq as whitespace-separated ints.

     Blank lines close the current dialog. The "Read N dialogs" log line is
     emitted only once the generator has been fully consumed.

     :param path: dialog file, opened with IO.reader
     :param add_eos: when true, append EOS_TOK_IDX to any token sequence that
        does not already end with it
     :return: generator of Dialog objects
     """
     num_dialogs = 0
     with IO.reader(path) as reader:
         current = Dialog()
         for raw_line in reader:
             raw_line = raw_line.strip()
             if not raw_line:
                 # blank separator line: emit the dialog collected so far, if any
                 if len(current) > 0:
                     yield current
                     num_dialogs += 1
                     current = Dialog()
                 continue

             cols = raw_line.split("\t")
             char_col, seq_col = cols[-2:]  # the last two are mandatory
             uid = cols[0] if len(cols) > 2 else None
             weight = float(cols[1]) if len(cols) > 3 else None
             char_id = int(char_col)
             tokens = [int(tok) for tok in seq_col.strip().split()]
             if add_eos and tokens[-1] != EOS_TOK_IDX:
                 tokens.append(EOS_TOK_IDX)
             current.append(Utterance(char_id, tokens, uid=uid, weight=weight))

         # the file may not end with a blank line; flush the last dialog
         if len(current) > 0:
             num_dialogs += 1
             yield current
     log.info(f"Read {num_dialogs} dialogs")
Exemplo n.º 2
0
 def write_tsv(records: Iterator[DialogRecord], path: Union[str, Path]):
     """
     Write (key, sequence) records to a two-column TSV file.

     Each record is unpacked as a pair. The first item is converted with
     str(); the items of the second are each converted with str() and joined
     with single spaces. One record is written per line.

     :param records: iterable of 2-item records
     :param path: destination, opened with IO.writer
     """
     log.info(f"Storing data at {path}")
     with IO.writer(path) as f:
         for key, values in records:
             key_str = str(key)
             values_str = ' '.join(map(str, values))
             f.write(key_str + '\t' + values_str + '\n')
Exemplo n.º 3
0
    def store_model(self, step: int, model, train_score: float,
                    val_score: float, keep: int):
        """
        Save a model checkpoint, prune surplus checkpoints, and log its scores.

        Does nothing (apart from a warning) when ``self.read_only`` is set.
        Otherwise the model is written with ``torch.save`` to
        ``self.model_dir / 'model_<step:05d>_<train:.6f>_<val:.6f>.pkl'``.
        Every model beyond the first ``keep`` entries of
        ``self.list_models(sort='total_score', desc=False)`` is then deleted.
        Finally, a row of step, ISO timestamp, file name, train score and
        validation score is appended to ``scores.tsv`` in the model directory.

        :param step: step number of training
        :param model: model object itself
        :param train_score: score of model on training split
        :param val_score: score of model on validation split
        :param keep: number of good models to keep, bad models will be deleted
        :return: None
        """
        # TODO: improve this by skipping the model save if the model is not good enough to be saved
        if self.read_only:
            log.warning("Ignoring the store request; experiment is readonly")
            return

        ckpt_name = f'model_{step:05d}_{train_score:.6f}_{val_score:.6f}.pkl'
        ckpt_path = self.model_dir / ckpt_name
        log.info(f"Saving... step={step} to {ckpt_path}")
        torch.save(model, str(ckpt_path))

        # NOTE(review): ascending order keeps the `keep` lowest total scores;
        # presumably total_score is loss-like (lower is better) -- confirm.
        ranked = self.list_models(sort='total_score', desc=False)
        for stale in ranked[keep:]:
            log.info(f"Deleting bad model {stale} . Keep={keep}")
            os.remove(str(stale))

        scores_file = os.path.join(self.model_dir, 'scores.tsv')
        with IO.writer(scores_file, append=True) as f:
            row = [
                str(step),
                datetime.now().isoformat(),
                ckpt_name,
                f'{train_score:g}',
                f'{val_score:g}',
            ]
            f.write('\t'.join(row) + '\n')
Exemplo n.º 4
0
 def write_lines(path: Union[str, Path], lines):
     """
     Write every item of ``lines`` to ``path`` on its own line.

     Each item has leading and trailing whitespace stripped before a single
     newline is appended. The number of lines written is logged.

     :param path: destination, opened with IO.writer
     :param lines: iterable of strings
     """
     written = 0
     with IO.writer(path) as out:
         for written, text in enumerate(lines, start=1):
             out.write(text.strip())
             out.write("\n")
         log.info(f"Wrote {written} lines to {path}")
Exemplo n.º 5
0
 def write_dialogs(dialogs: Iterator[Dialog], out: Path, dialog_sep='\n'):
     """
     Serialize dialogs to a text file, one utterance per line.

     Each utterance in ``dialog.chat`` becomes tab-separated columns
     [uid <TAB>] [weight <TAB>] char <TAB> text, where:

     * the uid column appears only when ``utter.uid`` is truthy;
     * the weight column (formatted with ``:g``) appears only when
       ``utter.weight`` is truthy;
     * text is ``utter.text`` with each item passed through str() and joined
       by single spaces.

     ``dialog_sep`` is written after every dialog, including the last.

     NOTE(review): because of the truthiness tests, a weight of 0 is not
     written. A weight written without a uid ends up in the first of three
     columns, where a reader that treats that column as the uid would
     misread it. Confirm this is intended.

     :param dialogs: iterable of Dialog objects
     :param out: destination, opened with IO.writer
     :param dialog_sep: text written after each dialog
     """
     n_dialogs = 0
     with IO.writer(out) as outh:
         for dialog in dialogs:
             n_dialogs += 1
             for turn in dialog.chat:
                 head = ''
                 if turn.uid:
                     head += f'{turn.uid}\t'
                 if turn.weight:
                     head += f'{turn.weight:g}\t'
                 body = " ".join(map(str, turn.text))
                 outh.write(f'{head}{turn.char}\t{body}\n')
             outh.write(dialog_sep)
     log.info(f"Wrote {n_dialogs} recs to {out}")
Exemplo n.º 6
0
def read_msg_resp(path: str):
    """
    Read message/response pairs from a two-column, tab-separated source.

    Each line is stripped of surrounding whitespace. Blank lines, and lines
    that do not split into exactly two tab-separated fields, are skipped.

    :param path: a file path (``str`` or any ``os.PathLike`` such as
        ``pathlib.Path``), which is opened with IO.reader; any other value is
        treated as an already-open iterable of lines
    :return: tuple ``(msgs, resps)`` of equal-length lists of strings
    """
    import os

    def _read(rdr):
        msgs, resps = [], []
        for line in rdr:
            line = line.strip()
            if not line:
                continue
            fields = line.split('\t')
            if len(fields) == 2:
                msgs.append(fields[0])
                resps.append(fields[1])
        return msgs, resps

    # A `type(path) is str` check would send pathlib.Path objects down the
    # "open reader" branch, where iterating a Path raises TypeError.
    if isinstance(path, (str, os.PathLike)):
        with IO.reader(path) as r:
            return _read(r)
    return _read(path)
Exemplo n.º 7
0
    def __init__(self,
                 inp: Union[str, Path, TextIO, Iterator[str]],
                 text_field: Field = None,
                 char_field: LookupField = None,
                 max_seq_len: int = 100,
                 add_eos=True):
        """
        Set up a reader over a dialog sequence source.

        :param inp: dialog seq file. A ``str`` is converted to ``Path``, and a
            ``Path`` is opened with ``IO.reader(...).open()``. Any other value
            is stored as ``self.reader`` unchanged (presumably an open text
            stream or an iterator of lines).
        :param text_field: provide this field if you want to map text to word ids.
         by default it splits words by white space and return words as sequence
        :param char_field: provide this field if you want to map character name to id.
        :param max_seq_len: stored as ``self.max_seq_len``; presumably the
            maximum sequence length applied while reading (usage not visible
            here)
        :param add_eos: stored as ``self.add_eos``; presumably controls
            appending an end-of-sequence token (usage not visible here)
        :raises FileNotFoundError: if a path is given that does not exist
        """
        if isinstance(inp, str):
            inp = Path(inp)
        if isinstance(inp, Path):
            # Explicit check rather than `assert`, which is stripped under -O.
            if not inp.exists():
                raise FileNotFoundError(f"dialog seq file not found: {inp}")
            # NOTE(review): this handle is not closed here; presumably the
            # owner of self.reader is responsible for closing it -- confirm.
            inp = IO.reader(inp).open()
        self.reader = inp
        self.text_field = text_field
        self.char_field = char_field
        self.max_seq_len = max_seq_len
        self.add_eos = add_eos
        self.num_cols = 0
Exemplo n.º 8
0
 def read_raw_lines(dialog_path: Union[str, Path]) -> Iterator[RawRecord]:
     """
     Yield ``(char, dialog)`` string pairs taken from the last two
     tab-separated columns of each line.

     Both columns are stripped of surrounding whitespace. The following lines
     are skipped: lines with fewer than two columns (for example the blank
     lines that separate dialogs), and lines where either stripped column is
     empty.

     :param dialog_path: file to read, opened with IO.reader
     :return: generator of ``(char, dialog)`` tuples
     """
     with IO.reader(dialog_path) as lines:
         for line in lines:
             cols = line.split("\t")
             # A line without any tab has a single column. Unpacking its
             # 1-element tail would raise ValueError, so skip it.
             if len(cols) < 2:
                 continue
             char, dialog = (col.strip() for col in cols[-2:])
             if char and dialog:
                 yield char, dialog
Exemplo n.º 9
0
 def store_config(self):
     """
     Dump ``self.config`` as block-style YAML into ``self._config_file``.

     The file is opened with IO.writer and ``default_flow_style=False`` is
     passed to ``yaml.dump``.

     :return: the value returned by ``yaml.dump``; PyYAML returns None when a
        stream argument is given
     """
     config_writer = IO.writer(self._config_file)
     with config_writer as stream:
         return yaml.dump(self.config, stream, default_flow_style=False)
Exemplo n.º 10
0
 def _read_char_names():
     """
     Yield the second-to-last tab-separated column of each line of ``path``.

     ``path`` is not a parameter. It is looked up from an enclosing or global
     scope that is not shown here. Each line is stripped before splitting,
     and lines with fewer than two columns are skipped.
     """
     with IO.reader(path) as inp:
         for row in inp:
             cols = row.strip().split('\t')
             if len(cols) < 2:
                 continue
             yield cols[-2]
Exemplo n.º 11
0
def read_lines(path: Union[str, Path]):
    """
    Return every line of ``path`` as a list, each stripped of surrounding
    whitespace. Blank lines are kept as empty strings.

    :param path: file to read, opened with IO.reader
    :return: list of stripped lines
    """
    with IO.reader(path) as reader:
        raw = reader.readlines()
    return [entry.strip() for entry in raw]
Exemplo n.º 12
0
def read_tsv(path: str):
    """
    Lazily yield the list of tab-separated fields for each line of ``path``.

    Lines are split on the tab character without any stripping. If the reader
    yields lines with their newline, the last field of each row keeps it.

    Because this is a generator, the existence check runs on the first
    ``next()`` call, not when ``read_tsv`` itself is called.

    :param path: file to read, opened with IO.reader
    :return: generator of lists of strings
    """
    assert os.path.exists(path)
    with IO.reader(path) as stream:
        for row in stream:
            yield row.split('\t')
Exemplo n.º 13
0
def read_lines(path):
    """
    Yield the lines produced by ``read_lines_reader``.

    :param path: a file path (``str`` or any ``os.PathLike``), opened with
        IO.reader for the duration of iteration; any other value is passed to
        ``read_lines_reader`` directly (presumably an already-open reader)
    :return: generator over whatever ``read_lines_reader`` yields
    """
    import os

    if isinstance(path, (str, os.PathLike)):
        with IO.reader(path) as reader:
            yield from read_lines_reader(reader)
    else:
        # This function is a generator. A plain `return <value>` here would end
        # iteration without yielding anything, so delegate with `yield from`.
        yield from read_lines_reader(path)