Code example #1
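    # Serializes parallel sequence records into an SQLite file: each sequence is
    # stored as an int32 byte blob along with its length; the DB is built on a
    # tmpfs path when available and then copied back to the target location.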
    def write(cls, path, records: Iterator[ParallelSeqRecord]):
        if path.exists():
            log.warning(f"Overwriting {path} with new records")
            os.remove(str(path))
        maybe_tmp = IO.maybe_tmpfs(path)
        log.info(f'Creating {maybe_tmp}')
        conn = sqlite3.connect(str(maybe_tmp))
        cur = conn.cursor()
        cur.execute(cls.TABLE_STATEMENT)
        cur.execute(cls.INDEX_X_LEN)
        cur.execute(cls.INDEX_Y_LEN)
        cur.execute(f"PRAGMA user_version = {cls.CUR_VERSION};")

        count = 0
        for x_seq, y_seq in records:
            # use numpy; it's a lot more efficient
            if not isinstance(x_seq, np.ndarray):
                x_seq = np.array(x_seq, dtype=np.int32)
            if y_seq is not None and not isinstance(y_seq, np.ndarray):
                y_seq = np.array(y_seq, dtype=np.int32)
            values = (x_seq.tobytes(),
                      None if y_seq is None else y_seq.tobytes(), len(x_seq),
                      len(y_seq) if y_seq is not None else -1)
            cur.execute(cls.INSERT_STMT, values)
            count += 1
        cur.close()
        conn.commit()
        if maybe_tmp != path:
            # bring the file back to original location where it should be
            IO.copy_file(maybe_tmp, path)
        log.info(f"stored {count} rows in {path}")
Code example #2
File: fork.py Project: MGheini/rtg
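# Forks an experiment directory: optionally copies conf.yml, symlinks the data
# directory (marking it _PREPARED), and copies vocabulary files into the new experiment.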
def fork_experiment(from_exp: Path, to_exp: Path, conf: bool, vocab: bool,
                    data: bool):
    assert from_exp.exists()
    log.info(f'Fork: {str(from_exp)} → {str(to_exp)}')
    if not to_exp.exists():
        log.info(f"Create dir {str(to_exp)}")
        to_exp.mkdir(parents=True)
    if conf:
        conf_file = to_exp / 'conf.yml'
        IO.maybe_backup(conf_file)
        IO.copy_file(from_exp / 'conf.yml', conf_file)
    if data:
        to_data_dir = (to_exp / 'data')
        from_data_dir = from_exp / 'data'
        if to_data_dir.is_symlink():
            log.info(
                f"removing the existing data link: {to_data_dir.resolve()}")
            to_data_dir.unlink()
        assert not to_data_dir.exists()
        assert from_data_dir.exists()
        log.info(f"link {to_data_dir} → {from_data_dir}")
        to_data_dir.symlink_to(from_data_dir.resolve())
        (to_exp / '_PREPARED').touch(exist_ok=True)
    if not data and vocab:  # just the vocab

        Experiment(from_exp, read_only=True).copy_vocabs(
            Experiment(to_exp, config={'Not': 'Empty'}, read_only=True))
Code example #3
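 # Reads a source/target file pair in lockstep and yields stripped (src, tgt)
 # pairs, skipping records that are empty on either side.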
 def read_raw_parallel_lines(src_path: Union[str, Path], tgt_path: Union[str, Path]) \
         -> Iterator[RawRecord]:
     with IO.reader(src_path) as src_lines, IO.reader(
             tgt_path) as tgt_lines:
         # if you get an exception here --> files have an unequal number of lines
         recs = ((src.strip(), tgt.strip())
                 for src, tgt in zip_longest(src_lines, tgt_lines))
         recs = ((src, tgt) for src, tgt in recs if src and tgt)
         yield from recs
Code example #4
File: dummy.py Project: MGheini/rtg
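# Writes integer sequences as parallel text: each (src, tgt) pair becomes one
# space-joined line in the source file and one in the target file.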
def write_parallel(data, src_file, tgt_file):
    count = 0
    with IO.writer(src_file) as src_f, IO.writer(tgt_file) as tgt_f:
        for src_seq, tgt_seq in data:
            src_seq = ' '.join(map(str, src_seq))
            tgt_seq = ' '.join(map(str, tgt_seq))
            src_f.write(f'{src_seq}\n')
            tgt_f.write(f'{tgt_seq}\n')
            count += 1
    log.info(f"Wrote {count} records to {src_file} and {tgt_file}")
Code example #5
File: pipeline.py Project: MGheini/rtg
 def moses_detokenize(self,
                      inp: Path,
                      out: Path,
                      col=0,
                      lang='en',
                      post_op=None):
     log.info(f"detok : {inp} --> {out}")
     tok_lines = IO.get_lines(inp, col=col, line_mapper=lambda x: x.split())
     with MosesDetokenizer(lang=lang) as detok:
         detok_lines = (detok(tok_line) for tok_line in tok_lines)
         if post_op:
             detok_lines = (post_op(line) for line in detok_lines)
         IO.write_lines(out, detok_lines)
Code example #6
File: codec.py Project: isi-nlp/rtg
    def train(cls,
              model_type: str,
              vocab_size: int,
              model_path: Union[Path, str],
              files: List[str],
              tok_coverage=0.9999,
              **kwargs):
        # Note: char_coverage is abused as subword_coverage
        hub_api = cls.load_hub_model(model_type)
        bpe = hub_api.bpe
        dicto = hub_api.task.dictionary

        freqs = coll.Counter()
        lines = IO.get_liness(*files)
        for line in tqdm(lines, mininterval=2, dynamic_ncols=True,
                         unit='line'):
            freqs.update(bpe.encode(line).split())
        total_toks = sum(freqs.values())
        log.info(f"Found {len(freqs)} bpe types and {total_toks} toks")

        freqs = list(sorted(freqs.items(), reverse=True, key=lambda x: x[1]))
        vocabulary, oovs = [], []
        cumulative = 0
        for t, f in freqs:
            if cumulative / total_toks <= tok_coverage:
                vocabulary.append((t, f))
                cumulative += f
            else:
                oovs.append((t, f))

        oovs_str = ' '.join(f'{t}:{f}' for t, f in oovs)
        log.info(f'Excluded {len(oovs)} types as OOVs.\n:{oovs_str}')
        log.info(f'Included {len(vocabulary)} types as in vocabulary; '
                 f'Coverage = {cumulative / total_toks:g}')
        # TODO: mapping should be list[int] with a one-to-one map
        types, indices = [], {}
        for typ, new_idx in cls.reserved():
            assert len(types) == new_idx
            types.append(typ)
            old_idx = dicto.indices.get(typ, -1)
            indices[typ] = [new_idx, old_idx]

        for typ, freq in vocabulary:
            # [new index, old index]
            indices[typ] = [len(types), dicto.indices.get(typ, -1)]
            types.append(typ)

        data = {'model_id': model_type, 'mapping': indices}
        with IO.writer(model_path) as wrtr:
            yaml.dump(data, wrtr)
        return cls(model_path)
Code example #7
File: codec.py Project: isi-nlp/rtg
    def train(cls,
              model_type: str,
              vocab_size: int,
              model_path: str,
              files: List[str],
              no_split_toks: Optional[List[str]] = None,
              char_coverage: float = 0,
              dedup=True,
              spark=None):
        """
        :param model_type: word, char, bpe
        :param vocab_size: vocabulary size
        :param model_path: where to store vocabulary model
        :param files: text files for creating the vocabulary
        :param no_split_toks:
        :param char_coverage: character coverage (0, 1]. value <= 0 => default coverage
        :return:
        """
        assert not no_split_toks, 'not supported in nlcodec yet'
        from nlcodec import learn_vocab, term_freq
        kwargs = dict(char_coverage=char_coverage) if char_coverage > 0 else {}
        if not spark:
            inp = IO.get_liness(*files)
        else:
            # extract and store frequencies to this file
            stats_file = model_path + '.termfreqs'
            if not Path(stats_file).exists():
                log.info("Extracting term frequencies... ")
                paths = [f if isinstance(f, Path) else Path(f) for f in files]
                wfs, chfs, n_lines = term_freq.word_counts(paths=paths,
                                                           dedup=dedup,
                                                           spark=spark)
                log.info(
                    f"Lines = {n_lines:,}, Word Types: {len(wfs):,} Char Types:{len(chfs):,}"
                )
                stats = chfs if model_type == 'char' else wfs
                log.info(f"Writing frequencies to {stats_file}")
                with IO.writer(stats_file) as out:
                    term_freq.write_stats(stats=stats,
                                          out=out,
                                          line_count=n_lines)
                kwargs['term_freqs'] = True
            inp = IO.get_lines(stats_file, delim='\n')

        learn_vocab(inp=inp,
                    level=model_type,
                    model=model_path,
                    vocab_size=vocab_size,
                    **kwargs)
        return cls(model_path)
Code example #8
    def __init__(self,
                 models: List[Path],
                 exp: Union[Path, TranslationExperiment],
                 lr: float = 1e-4,
                 smoothing=0.1):
        if isinstance(exp, Path):
            exp = TranslationExperiment(exp)
        self.w_file = exp.work_dir / f'combo-weights.yml'

        wt = None
        if self.w_file.exists():
            with IO.reader(self.w_file) as rdr:
                combo_spec = yaml.load(rdr)
            weights = combo_spec['weights']
            assert len(weights) == len(
                models)  # same models as before: no messing allowed
            model_path_strs = [str(m) for m in models]
            for m in model_path_strs:
                assert m in weights, f'{m} not found in weights file.'
            wt = [weights[str(m)] for m in model_path_strs]
            log.info(f"restoring previously stored weights {wt}")

        from rtg.module.decoder import load_models
        combo = Combo(load_models(models, exp), model_paths=models, w=wt)
        self.combo = combo.to(device)
        self.exp = exp
        self.optim = torch.optim.Adam(combo.parameters(), lr=lr)
        self.criterion = LabelSmoothing(vocab_size=combo.vocab_size,
                                        padding_idx=PAD_TOK_IDX,
                                        smoothing=smoothing)
Code example #9
    def shell_pipe(cls, cmd_line, inp, out):
        """

        :param cmd_line: shell command line
        :param inp: input file, to read records
        :param out:  output file to store records
        :return:
        """
        log.info("Shell cmd:: {cmd_line}")
        with IO.reader(inp) as rdr, IO.writer(out) as wtr:
            proc = subprocess.Popen(cmd_line,
                                    stdin=rdr,
                                    stdout=wtr,
                                    shell=True)
            proc.wait()
        log.info("Shell cmd:: Done")
Code example #10
File: exp.py Project: MGheini/rtg
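 # Reads a vocabulary file, keeping only the first token of each line and
 # optionally stripping the leading sentencepiece '▁' marker when do_clean is set.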
 def _read_vocab(path: Path) -> List[str]:
     with IO.reader(path) as rdr:
         vocab = [line.strip().split()[0] for line in rdr]
         if do_clean:
             # sentence piece starts with '▁' character
             vocab = [
                 word[1:] if word[0] == '▁' else word for word in vocab
             ]
         return vocab
Code example #11
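 # Reads monolingual text, tokenizes non-empty lines, and either truncates
 # records to max_len or filters out the ones that are too long.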
 def read_raw_mono_recs(path: Union[str, Path], truncate: bool,
                        max_len: int, tokenizer):
     with IO.reader(path) as lines:
         recs = (tokenizer(line.strip()) for line in lines if line.strip())
         if truncate:
             recs = (rec[:max_len] for rec in recs)
         else:  # Filter out longer sentences
             recs = (rec for rec in recs if 0 < len(rec) <= max_len)
         yield from recs
Code example #12
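 # Scores detokenized hypotheses against the reference with sacrebleu's corpus_bleu
 # and writes the formatted score to a .sacrebleu file next to the hypothesis.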
 def evaluate_file(self,
                   detok_hyp: Path,
                   ref: Union[Path, List[str]],
                   lowercase=True) -> float:
     detok_lines = list(IO.get_lines(detok_hyp))
     # takes multiple refs, but here we have only one
     ref_liness = [IO.get_lines(ref) if isinstance(ref, Path) else ref]
     bleu: BLEUScore = corpus_bleu(sys_stream=detok_lines,
                                   ref_streams=ref_liness,
                                   lowercase=lowercase)
     # this should be part of the new sacrebleu release (I sent a PR ;)
     bleu_str = bleu.format()
     bleu_file = detok_hyp.with_name(detok_hyp.name +
                                     ('.lc' if lowercase else '.oc') +
                                     '.sacrebleu')
     log.info(f'BLEU {detok_hyp} : {bleu_str}')
     IO.write_lines(bleu_file, bleu_str)
     return bleu.score
Code example #13
File: pipeline.py Project: MGheini/rtg
 def evaluate_file(self,
                   detok_hyp: Path,
                   ref: Union[Path, List[str]],
                   lowercase=True) -> float:
     detok_lines = IO.get_lines(detok_hyp)
     # takes multiple refs, but here we have only one
     ref_liness = [IO.get_lines(ref) if isinstance(ref, Path) else ref]
     bleu: BLEU = corpus_bleu(sys_stream=detok_lines,
                              ref_streams=ref_liness,
                              lowercase=lowercase)
     # this should be part of the new sacrebleu release (I sent a PR ;)
     bleu_str = f'BLEU = {bleu.score:.2f} {"/".join(f"{p:.1f}" for p in bleu.precisions)}' \
         f' (BP = {bleu.bp:.3f} ratio = {(bleu.sys_len / bleu.ref_len):.3f}' \
         f' hyp_len = {bleu.sys_len:d} ref_len={bleu.ref_len:d})'
     bleu_file = detok_hyp.with_suffix(('.lc' if lowercase else '.oc') +
                                       '.sacrebleu')
     log.info(f'BLEU {detok_hyp} : {bleu_str}')
     IO.write_lines(bleu_file, bleu_str)
     return bleu.score
Code example #14
File: exp.py Project: MGheini/rtg
 def copy_vocabs(self, other):
     """
     Copies vocabulary files from self to other
     :param other: other experiment
     :return:
     """
     other: TranslationExperiment = other
     if not other.data_dir.exists():
         other.data_dir.mkdir(parents=True)
     for source, destination in [
         (self._src_field_file, other._src_field_file),
         (self._tgt_field_file, other._tgt_field_file),
         (self._shared_field_file, other._shared_field_file)
     ]:
         if source.exists():
             IO.copy_file(source, destination)
             src_txt_file = source.with_name(
                 source.name.replace('.model', '.vocab'))
             if src_txt_file.exists():
                 dst_txt_file = destination.with_name(
                     destination.name.replace('.model', '.vocab'))
                 IO.copy_file(src_txt_file, dst_txt_file)
Code example #15
File: exp.py Project: MGheini/rtg
    def store_model(self,
                    epoch: int,
                    model,
                    train_score: float,
                    val_score: float,
                    keep: int,
                    prefix='model',
                    keeper_sort='step'):
        """
        saves model to a given path
        :param epoch: epoch number of model
        :param model: model object itself
        :param train_score: score of model on training split
        :param val_score: score of model on validation split
        :param keep: number of good models to keep, bad models will be deleted
        :param prefix: prefix to store model. default is "model"
        :param keeper_sort: criteria for choosing the old or bad models for deletion.
            Choices: {'total_score', 'step'}
        :return:
        """
        # TODO: improve this by skipping the model save if the model is not good enough to be saved
        if self.read_only:
            log.warning("Ignoring the store request; experiment is readonly")
            return
        name = f'{prefix}_{epoch:03d}_{train_score:.6f}_{val_score:.6f}.pkl'
        path = self.model_dir / name
        log.info(f"Saving epoch {epoch} to {path}")
        torch.save(model, str(path))

        del_models = []
        if keeper_sort == 'total_score':
            del_models = self.list_models(sort='total_score',
                                          desc=False)[keep:]
        elif keeper_sort == 'step':
            del_models = self.list_models(sort='step', desc=True)[keep:]
        else:
            raise Exception(f'Sort criteria {keeper_sort} not understood')
        for d_model in del_models:
            log.info(
                f"Deleting model {d_model} . Keep={keep}, sort={keeper_sort}")
            os.remove(str(d_model))

        with IO.writer(os.path.join(self.model_dir, 'scores.tsv'),
                       append=True) as f:
            cols = [
                str(epoch),
                datetime.now().isoformat(), name, f'{train_score:g}',
                f'{val_score:g}'
            ]
            f.write('\t'.join(cols) + '\n')
Code example #16
 def decode_eval_file(self,
                      decoder,
                      src: Union[Path, List[str]],
                      out_file: Path,
                      ref: Optional[Union[Path, List[str]]],
                      lowercase: bool = True,
                      **dec_args) -> float:
     if out_file.exists() and out_file.stat().st_size > 0 and line_count(
             out_file) == (len(src)
                           if isinstance(src, list) else line_count(src)):
         log.warning(
             f"{out_file} exists and has desired number of lines. Skipped..."
         )
     else:
         if isinstance(src, Path):
             log.info(f"decoding {src.name}")
             src = list(IO.get_lines(src))
         if isinstance(ref, Path):
             ref = list(IO.get_lines(ref))
         with IO.writer(out_file) as out:
             decoder.decode_file(src, out, **dec_args)
     detok_hyp = self.detokenize(out_file)
     if ref:
         return self.evaluate_file(detok_hyp, ref, lowercase=lowercase)
Code example #17
File: exp.py Project: isi-nlp/rtg
    def get_train_data(self,
                       batch_size: Union[int, Tuple[int, int]],
                       steps: int = 0,
                       sort_by='eq_len_rand_batch',
                       batch_first=True,
                       shuffle=False,
                       fine_tune=False,
                       keep_in_mem=False,
                       split_ratio: float = 0.,
                       dynamic_epoch=False):

        data_path = self.train_db if self.train_db.exists(
        ) else self.train_file
        if fine_tune:
            if not self.finetune_file.exists():
                # user may have added fine tune file later
                self._pre_process_parallel('finetune_src', 'finetune_tgt',
                                           self.finetune_file)
            log.info("Using Fine tuning corpus instead of training corpus")
            data_path = self.finetune_file

        if split_ratio > 0:
            data_path = IO.maybe_tmpfs(data_path)
            train_file = data_path.with_suffix('.db.tmp')
            file_creator = partial(self.file_creator,
                                   train_file=train_file,
                                   split_ratio=split_ratio)
            train_data = GenerativeBatchIterable(file_creator=file_creator,
                                                 batches=steps,
                                                 batch_size=batch_size,
                                                 field=self.tgt_vocab,
                                                 dynamic_epoch=dynamic_epoch,
                                                 batch_first=batch_first,
                                                 shuffle=shuffle,
                                                 sort_by=sort_by,
                                                 **self._get_batch_args())
        else:
            data = BatchIterable(data_path=data_path,
                                 batch_size=batch_size,
                                 field=self.tgt_vocab,
                                 sort_by=sort_by,
                                 batch_first=batch_first,
                                 shuffle=shuffle,
                                 **self._get_batch_args())
            train_data = LoopingIterable(data, steps)

        return train_data
Code example #18
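 # Streams tab-separated records as IdExample objects, truncating or skipping
 # records that exceed the maximum source/target lengths and ignoring empty ones.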
 def read_all(self) -> Iterator[IdExample]:
     with IO.reader(self.path) as lines:
         recs = (line.split('\t') for line in lines)
         for idx, rec in enumerate(recs):
             x = self._parse(rec[0].strip())
             y = self._parse(rec[1].strip()) if len(rec) > 1 else None
             if self.truncate:  # truncate long recs
                 x = x[:self.max_src_len]
                 y = y if y is None else y[:self.max_tgt_len]
             elif len(x) > self.max_src_len or (0 if y is None else
                                                len(y)) > self.max_tgt_len:
                 continue  # skip long recs
             if not x or (y is not None
                          and len(y) == 0):  # empty on one side
                 log.warning(
                     f"Ignoring an empty record  x:{len(x)}    y:{len(y)}")
                 continue
             yield IdExample(x, y, id=idx)
Code example #19
File: word2vec.py Project: MGheini/rtg
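 # Adds the embedding matrix to TensorBoard and saves it either as a gzipped
 # word2vec-style text file or as a pickle of words and vectors.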
 def save_embeddings(self, step, train_loss, val_loss, txt=True):
     matrix = self.model.emb.weight
     vocab = self.exp.shared_vocab
     words = [vocab.id_to_piece(i) for i in range(len(vocab))]
     self.tbd.add_embedding(matrix, metadata=words, global_step=step)
     ext = 'txt.gz' if txt else 'pkl'
     path = self.exp.model_dir / f'embeddings_{step}_{train_loss:.6f}_{val_loss:.6f}.{ext}'
     log.info(f"writing  embedding after step {step} to {path}")
     if txt:
         with IO.writer(path) as w:
             w.write(f'{matrix.shape[0]} {matrix.shape[1]}\n')
             for i in range(matrix.shape[0]):
                 word = words[i]
                 vect = ' '.join(f'{x:g}' for x in matrix[i])
                 w.write(f'{word} {vect}\n')
     else:
         with path.open('wb') as f:
             # detach the weight tensor from the graph before converting to numpy
             data = {'words': words, 'vectors': matrix.detach().cpu().numpy()}
             pickle.dump(data, f)
Code example #20
File: exp.py Project: isi-nlp/rtg
 def get_combo_data(self,
                    batch_size: int,
                    steps: int = 0,
                    sort_desc=False,
                    batch_first=True,
                    shuffle=False):
     if not self.combo_file.exists():
         # user may have added fine tune file later
         self._pre_process_parallel('combo_src', 'combo_tgt',
                                    self.combo_file)
     combo_file = IO.maybe_tmpfs(self.combo_file)
     data = BatchIterable(combo_file,
                          batch_size=batch_size,
                          sort_desc=sort_desc,
                          field=self.tgt_vocab,
                          batch_first=batch_first,
                          shuffle=shuffle,
                          **self._get_batch_args())
     if steps > 0:
         data = LoopingIterable(data, steps)
     return data
Code example #21
File: decode.py Project: MGheini/rtg
def validate_args(args, exp: Experiment):
    if not args.pop('skip_check'):  # if --skip-check is not requested
        assert exp.has_prepared(), \
            f'Experiment dir {exp.work_dir} is not ready to train. Please run "prep" sub task'
        assert exp.has_trained(), \
            f'Experiment dir {exp.work_dir} is not ready to decode.' \
            f' Please run "train" sub task or --skip-check to ignore this'

    weights_file = exp.work_dir / 'combo-weights.yml'
    if not args.get('sys_comb') and weights_file.exists():
        log.warning("Found default combo weights, switching to combo mode")
        args['sys_comb'] = weights_file

    if args.get("sys_comb"):
        with IO.reader(args['sys_comb']) as fh:
            weights = yaml.load(fh)['weights']
            args['model_path'], args['weights'] = zip(*weights.items())
            for model in args['model_path']:
                assert Path(model).exists(), model
            assert abs(sum(args['weights']) - 1) < 1e-3, \
                f'Weights from --sys-comb file should sum to 1.0, given={args["weights"]}'
Code example #22
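    # Iterates over combo batches computing the combination loss, then saves the
    # per-model combination weights to a YAML file.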
    def train(self, steps: int, batch_size: int):
        log.info(f"Going to train for {steps}")
        batches = self.exp.get_combo_data(batch_size=batch_size, steps=steps)
        with tqdm(batches, total=steps, unit='step',
                  dynamic_ncols=True) as data_bar:
            for i, batch in enumerate(data_bar):
                batch = batch.to(device)
                y_probs = self.combo(batch)  # B x T x V
                loss = self.loss_func(y_probs,
                                      y_seqs=batch.y_seqs,
                                      norm=batch.y_toks)
                wt_str = ','.join(f'{wt:g}' for wt in self.combo.weight)
                progress_msg = f'loss={loss:g}, weights={wt_str}'
                data_bar.set_postfix_str(progress_msg, refresh=False)

        weights = dict(
            zip([str(x) for x in self.combo.model_paths],
                self.combo.model_weights.tolist()))
        log.info(f" Training finished. {weights}")
        with IO.writer(self.w_file) as wtr:
            yaml.dump(dict(weights=weights), wtr, default_flow_style=False)
Code example #23
File: codec.py Project: isi-nlp/rtg
    def __init__(self, path: Union[str, Path]):
        with IO.reader(path) as rdr:
            data = yaml.load(rdr)
        hub_api = self.load_hub_model(data['model_id'])
        # these are for XLM-R via RoBERTa from fairseq; generalize it for other models later
        self.bpe = hub_api.bpe

        self.tok2idx = {
            tok: new_idx
            for tok, (new_idx, old_idx) in data['mapping'].items()
        }
        self.idx2tok = list(
            sorted(self.tok2idx.keys(), key=self.tok2idx.get, reverse=False))
        assert len(self.idx2tok) == len(self.tok2idx)

        for tok, idx in self.reserved():  # reserved types must keep their reserved indices
            assert self.tok2idx[tok] == idx
            assert self.idx2tok[idx] == tok
        self.new_idx2old_idx = {
            new_idx: old_idx
            for tok, (new_idx, old_idx) in data['mapping'].items()
        }
Code example #24
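# Streams a TSV file, yielding the tab-split fields of each line.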
def read_tsv(path: str):
    assert os.path.exists(path)
    with IO.reader(path) as f:
        yield from (line.split('\t') for line in f)
Code example #25
    def tune_decoder_params(self,
                            exp: Experiment,
                            tune_src: str,
                            tune_ref: str,
                            batch_size: int,
                            trials: int = 10,
                            lowercase=True,
                            beam_size=(1, 4, 8),
                            ensemble=(1, 5, 10),
                            lp_alpha=(0.0, 0.4, 0.6),
                            suggested: List[Tuple[int, int, float]] = None,
                            **fixed_args):
        _, _, _, tune_args = inspect.getargvalues(inspect.currentframe())
        tune_args.update(fixed_args)
        ex_args = ['exp', 'self', 'fixed_args', 'batch_size', 'max_len']
        if trials == 0:
            ex_args += ['beam_size', 'ensemble', 'lp_alpha']
        for x in ex_args:
            del tune_args[x]  # exclude some args

        _, step = exp.get_last_saved_model()
        tune_dir = exp.work_dir / f'tune_step{step}'
        log.info(f"Tune dir = {tune_dir}")
        tune_dir.mkdir(parents=True, exist_ok=True)
        tune_src, tune_ref = Path(tune_src), Path(tune_ref)
        assert tune_src.exists()
        assert tune_ref.exists()
        tune_src, tune_ref = list(IO.get_lines(tune_src)), list(
            IO.get_lines(tune_ref))
        assert len(tune_src) == len(tune_ref)

        tune_log = tune_dir / 'scores.json'  # resume the tuning
        memory: Dict[Tuple, float] = {}
        if tune_log.exists():
            data = json.load(tune_log.open())
            # JSON keys can't be tuples, so they were stringified
            memory = {eval(k): v for k, v in data.items()}

        beam_sizes, ensembles, lp_alphas = [], [], []
        if suggested:
            if isinstance(suggested[0], str):
                suggested = [eval(x) for x in suggested]
            suggested = [(x[0], x[1], round(x[2], 2)) for x in suggested]
            suggested_new = [x for x in suggested if x not in memory]
            beam_sizes += [x[0] for x in suggested_new]
            ensembles += [x[1] for x in suggested_new]
            lp_alphas += [x[2] for x in suggested_new]

        new_trials = trials - len(memory)
        if new_trials > 0:
            beam_sizes += [random.choice(beam_size) for _ in range(new_trials)]
            ensembles += [random.choice(ensemble) for _ in range(new_trials)]
            lp_alphas += [
                round(random.choice(lp_alpha), 2) for _ in range(new_trials)
            ]

        # ensembling is somewhat costlier, so minimize it by grouping trials with the same ensemble size
        grouped_ens = defaultdict(list)
        for b, ens, l in zip(beam_sizes, ensembles, lp_alphas):
            grouped_ens[ens].append((b, l))
        try:
            for ens, args in grouped_ens.items():
                decoder = Decoder.new(exp, ensemble=ens)
                for b_s, lp_a in args:
                    eff_batch_size = batch_size // b_s  # effective batch size
                    name = f'tune_step{step}_beam{b_s}_ens{ens}_lp{lp_a:.2f}'
                    log.info(name)
                    out_file = tune_dir / f'{name}.out.tsv'
                    score = self.decode_eval_file(decoder,
                                                  tune_src,
                                                  out_file,
                                                  tune_ref,
                                                  batch_size=eff_batch_size,
                                                  beam_size=b_s,
                                                  lp_alpha=lp_a,
                                                  lowercase=lowercase,
                                                  **fixed_args)
                    memory[(b_s, ens, lp_a)] = score
            best_params = sorted(memory.items(),
                                 key=lambda x: x[1],
                                 reverse=True)[0][0]
            return dict(zip(['beam_size', 'ensemble', 'lp_alpha'],
                            best_params)), tune_args
        finally:
            # JSON keys can't be tuples, so we stringify them
            data = {str(k): v for k, v in memory.items()}
            IO.write_lines(tune_log, json.dumps(data))
Code example #26
 def write_lines(lines, path):
     log.info(f"Storing data at {path}")
     with IO.writer(path) as f:
         for line in lines:
             f.write(line)
             f.write('\n')
Code example #27
File: trainer.py Project: MGheini/rtg
    def __init__(self,
                 exp: Experiment,
                 model: Optional[NMTModel] = None,
                 model_factory: Optional[Callable] = None,
                 optim: str = 'ADAM',
                 **optim_args):
        self.start_step = 0
        self.last_step = -1
        self.exp = exp
        optim_state = None
        if model:
            self.model = model
        else:
            args = exp.model_args
            assert args
            assert model_factory
            self.model, args = model_factory(exp=exp, **args)
            exp.model_args = args
            last_model, self.last_step = self.exp.get_last_saved_model()
            if last_model:
                self.start_step = self.last_step + 1
                log.info(
                    f"Resuming training from step:{self.start_step}, model={last_model}"
                )
                state = torch.load(last_model)
                model_state = state[
                    'model_state'] if 'model_state' in state else state
                if 'optim_state' in state:
                    optim_state = state['optim_state']
                self.model.load_state_dict(model_state)
            else:
                log.info(
                    "No earlier check point found. Looks like this is a fresh start"
                )

        # making optimizer
        optim_args['lr'] = optim_args.get('lr', 0.1)
        optim_args['betas'] = optim_args.get('betas', [0.9, 0.98])
        optim_args['eps'] = optim_args.get('eps', 1e-9)

        warmup_steps = optim_args.pop('warmup_steps', 8000)
        self._smoothing = optim_args.pop('label_smoothing', 0.1)
        constant = optim_args.pop('constant', 2)

        self.model = self.model.to(device)

        inner_opt = Optims[optim].new(self.model.parameters(), **optim_args)
        if optim_state:
            log.info("restoring optimizer state from checkpoint")
            try:
                inner_opt.load_state_dict(optim_state)
            except Exception:
                log.exception("Unable to restore optimizer, skipping it.")
        self.opt = NoamOpt(self.model.model_dim,
                           constant,
                           warmup_steps,
                           inner_opt,
                           step=self.start_step)

        optim_args.update(
            dict(warmup_steps=warmup_steps,
                 label_smoothing=self._smoothing,
                 constant=constant))
        if self.exp.read_only:
            self.tbd = NoOpSummaryWriter()
        else:
            self.tbd = SummaryWriter(log_dir=str(exp.work_dir / 'tensorboard'))

        self.exp.optim_args = optim, optim_args
        if not self.exp.read_only:
            self.exp.persist_state()
        self.samples = None
        if exp.samples_file.exists():
            with IO.reader(exp.samples_file) as f:
                self.samples = [line.strip().split('\t') for line in f]
                log.info(f"Found {len(self.samples)} sample records")
                if self.start_step == 0:
                    for samp_num, sample in enumerate(self.samples):
                        self.tbd.add_text(f"sample/{samp_num}",
                                          " || ".join(sample), 0)

            from rtg.module.decoder import Decoder
            self.decoder = Decoder.new(self.exp, self.model)

        if self.start_step == 0:
            self.init_embeddings()
        self.model = self.model.to(device)
Code example #28
File: exp.py Project: isi-nlp/rtg
    def inherit_parent(self):
        parent = self.config['parent']
        parent_exp = TranslationExperiment(parent['experiment'],
                                           read_only=True)
        log.info(f"Parent experiment: {parent_exp.work_dir}")
        parent_exp.has_prepared()
        vocab_sepc = parent.get('vocab')
        if vocab_sepc:
            log.info(f"Parent vocabs inheritance spec: {vocab_sepc}")
            codec_lib = parent_exp.config['prep'].get('codec_lib')
            if codec_lib:
                self.config['prep']['codec_lib'] = codec_lib

            def _locate_field_file(exp: TranslationExperiment,
                                   name,
                                   check_exists=False) -> Path:
                switch = {
                    'src': exp._src_field_file,
                    'tgt': exp._tgt_field_file,
                    'shared': exp._shared_field_file
                }
                assert name in switch, f'{name} not allowed; valid options= {switch.keys()}'
                file = switch[name]
                if check_exists:
                    assert file.exists(
                    ), f'{file} does not exist; for {name} of {exp.work_dir}'
                return file

            for to_field, from_field in vocab_sepc.items():
                from_field_file = _locate_field_file(parent_exp,
                                                     from_field,
                                                     check_exists=True)
                to_field_file = _locate_field_file(self,
                                                   to_field,
                                                   check_exists=False)
                IO.copy_file(from_field_file, to_field_file)
            self.reload_vocabs()
        else:
            log.info("No vocabularies are inherited from parent")
        model_sepc = parent.get('model')
        if model_sepc:
            log.info("Parent model inheritance spec")
            if model_sepc.get('args'):
                self.model_args = parent_exp.model_args
            ensemble = model_sepc.get('ensemble', 1)
            model_paths = parent_exp.list_models(sort='step',
                                                 desc=True)[:ensemble]
            log.info(
                f"Averaging {len(model_paths)} checkpoints of parent model: \n{model_paths}"
            )
            from rtg.module.decoder import Decoder
            avg_state = Decoder.average_states(model_paths=model_paths)
            log.info(
                f"Saving parent model's state to {self.parent_model_state}")
            torch.save(avg_state, self.parent_model_state)

        shrink_spec = parent.get('shrink')
        if shrink_spec:
            remap_src, remap_tgt = self.shrink_vocabs()

            def map_rows(mapping: List[int], source: torch.Tensor, name=''):
                assert max(mapping) < len(source)
                target = torch.zeros((len(mapping), *source.shape[1:]),
                                     dtype=source.dtype,
                                     device=source.device)
                for new_idx, old_idx in enumerate(mapping):
                    target[new_idx] = source[old_idx]
                log.info(f"Mapped {name} {source.shape} --> {target.shape} ")
                return target

            """ src_embed.0.lut.weight [N x d]
                tgt_embed.0.lut.weight [N x d]
                generator.proj.weight [N x d]
                generator.proj.bias [N] """
            if remap_src:
                key = 'src_embed.0.lut.weight'
                avg_state[key] = map_rows(remap_src, avg_state[key], name=key)
            if remap_tgt:
                map_keys = [
                    'tgt_embed.0.lut.weight', 'generator.proj.weight',
                    'generator.proj.bias'
                ]
                for key in map_keys:
                    if key not in avg_state:
                        log.warning(
                            f'{key} not found in avg_state of parent model. Mapping skipped'
                        )
                        continue
                    avg_state[key] = map_rows(remap_tgt,
                                              avg_state[key],
                                              name=key)
            if self.parent_model_state.exists():
                self.parent_model_state.rename(
                    self.parent_model_state.with_suffix('.orig'))
            torch.save(avg_state, self.parent_model_state)
            self.persist_state(
            )  # this will fix src_vocab and tgt_vocab of model_args conf
Code example #29
File: exp.py Project: MGheini/rtg
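 # Writes a dictionary to a file as tab-separated key/value lines.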
 def _write_dict(dict, path: Path):
     with IO.writer(path) as out:
         for key, val in dict.items():
             out.write(f"{key}\t{val}\n")
Code example #30
File: exp.py Project: MGheini/rtg
def load_conf(inp: Union[str, Path]):
    with IO.reader(inp) as fh:
        return yaml.load(fh)