Example #1
    def __init__(self,
                 models: List[Path],
                 exp: Union[Path, TranslationExperiment],
                 lr: float = 1e-4,
                 smoothing=0.1):
        if isinstance(exp, Path):
            exp = TranslationExperiment(exp)
        self.w_file = exp.work_dir / 'combo-weights.yml'

        wt = None
        if self.w_file.exists():
            with IO.reader(self.w_file) as rdr:
                # PyYAML >= 5.1 wants an explicit Loader; SafeLoader suffices here
                combo_spec = yaml.load(rdr, Loader=yaml.SafeLoader)
            weights = combo_spec['weights']
            # must be the same set of models as before; mixing is not allowed
            assert len(weights) == len(models)
            model_path_strs = [str(m) for m in models]
            for m in model_path_strs:
                assert m in weights, f'{m} not found in weights file.'
            wt = [weights[m] for m in model_path_strs]
            log.info(f"restoring previously stored weights {wt}")

        from rtg.module.decoder import load_models
        combo = Combo(load_models(models, exp), model_paths=models, w=wt)
        self.combo = combo.to(device)
        self.exp = exp
        self.optim = torch.optim.Adam(combo.parameters(), lr=lr)
        self.criterion = SmoothKLD(vocab_size=combo.vocab_size,
                                   padding_idx=exp.tgt_vocab.pad_idx,
                                   smoothing=smoothing)
Example #2
    def start_step(self):
        _, step = self.exp.get_last_saved_model()
        if self.exp._trained_flag.exists():
            # noinspection PyBroadException
            try:
                step = max(
                    step,
                    yaml.load(self.exp._trained_flag.read_text(),
                              Loader=yaml.SafeLoader)['steps'])
            except Exception as _:
                pass
        assert step >= 0
        return step
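Note: `start_step` resumes from whichever is larger: the last checkpoint's step or the `steps` value recorded in the `_trained_flag` YAML file. A minimal sketch of that flag file's shape, assuming the keys that Example #5 writes (`steps`, `early_stopped`, `finetune`); the flag path is hypothetical:

import yaml
from pathlib import Path

flag = Path('_TRAINED')  # hypothetical flag path
flag.write_text('steps: 100000\nearly_stopped: false\nfinetune: false\n')
step = yaml.load(flag.read_text(), Loader=yaml.SafeLoader)['steps']
assert step == 100000
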
Example #3
def validate_args(args, exp: Experiment):
    if not args.pop('skip_check'):  # if --skip-check is not requested
        assert exp.has_prepared(), \
            f'Experiment dir {exp.work_dir} is not ready to train. Please run "prep" sub task'
        assert exp.has_trained(), \
            f'Experiment dir {exp.work_dir} is not ready to decode.' \
            f' Please run "train" sub task or --skip-check to ignore this'

    weights_file = exp.work_dir / 'combo-weights.yml'
    if not args.get('sys_comb') and weights_file.exists():
        log.warning("Found default combo weights, switching to combo mode")
        args['sys_comb'] = weights_file

    if args.get("sys_comb"):
        with IO.reader(args['sys_comb']) as fh:
            weights = yaml.load(fh, Loader=yaml.SafeLoader)['weights']
            args['model_path'], args['weights'] = zip(*weights.items())
            for model in args['model_path']:
                assert Path(model).exists(), model
            assert abs(sum(args['weights']) - 1) < 1e-3, \
                f'Weights from --sys-comb file should sum to 1.0, given={args["weights"]}'
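Note: `zip(*weights.items())` splits the path-to-weight mapping into two parallel tuples, which is why `model_path` and `weights` line up by index. A toy demonstration:

# toy mapping from model path to interpolation weight
weights = {'m1.pkl': 0.7, 'm2.pkl': 0.3}
model_paths, wts = zip(*weights.items())
print(model_paths)  # ('m1.pkl', 'm2.pkl')
print(wts)          # (0.7, 0.3)
assert abs(sum(wts) - 1) < 1e-3  # same sanity check as above
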
Example #4
File: codec.py  Project: isi-nlp/rtg
    def __init__(self, path: Union[str, Path]):
        with IO.reader(path) as rdr:
            data = yaml.load(rdr, Loader=yaml.SafeLoader)
        hub_api = self.load_hub_model(data['model_id'])
        # these are for XLM-R (RoBERTa-based) loaded via fairseq's hub; generalize it for other models later
        self.bpe = hub_api.bpe

        self.tok2idx = {
            tok: new_idx
            for tok, (new_idx, old_idx) in data['mapping'].items()
        }
        self.idx2tok = sorted(self.tok2idx, key=self.tok2idx.get)
        assert len(self.idx2tok) == len(self.tok2idx)

        for tok, idx in self.reserved():  # reserved tokens must keep their fixed positions
            assert self.tok2idx[tok] == idx
            assert self.idx2tok[idx] == tok
        self.new_idx2old_idx = {
            new_idx: old_idx
            for tok, (new_idx, old_idx) in data['mapping'].items()
        }
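Note: `data['mapping']` maps each token to a `(new_idx, old_idx)` pair, and the constructor derives both lookup directions from it. A toy illustration of the same derivation (tokens and indices are made up):

# token -> (new_idx, old_idx), as in the codec's mapping
mapping = {'<pad>': (0, 0), '<s>': (1, 1), 'hello': (2, 42)}
tok2idx = {tok: new for tok, (new, old) in mapping.items()}
idx2tok = sorted(tok2idx, key=tok2idx.get)
new2old = {new: old for tok, (new, old) in mapping.items()}
assert idx2tok[2] == 'hello' and new2old[2] == 42
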
Example #5
File: exp.py  Project: isi-nlp/rtg
    def train(self, args=None):
        run_args = copy.deepcopy(self.config.get('trainer', {}))
        if args:
            run_args.update(args)
        if 'init_args' in run_args:
            del run_args['init_args']
        train_steps = run_args['steps']
        finetune_steps = run_args.pop('finetune_steps', None)
        finetune_batch_size = run_args.pop('finetune_batch_size',
                                           run_args.get('batch_size'))
        if finetune_steps:
            assert isinstance(finetune_steps, int)
            assert finetune_steps > train_steps, f'finetune_steps={finetune_steps} should be' \
                                                 f' greater than steps={train_steps}'

        _, last_step = self.get_last_saved_model()
        if self._trained_flag.exists():
            # noinspection PyBroadException
            try:
                last_step = max(
                    last_step,
                    yaml.load(self._trained_flag.read_text(),
                              Loader=yaml.SafeLoader)['steps'])
            except Exception as _:
                pass

        if last_step >= train_steps and (finetune_steps is None
                                         or last_step >= finetune_steps):
            log.warning(
                f"Already trained up to {last_step}; requested: train={train_steps}, "
                f"finetune={finetune_steps}. Skipping.")
            return

        from rtg.registry import trainers, factories
        name, optim_args = self.optim_args
        trainer = trainers[self.model_type](
            self,
            optim=name,
            model_factory=factories[self.model_type],
            **optim_args)
        if last_step < train_steps:  # regular training
            stopped = trainer.train(fine_tune=False, **run_args)
            if not self.read_only:
                status = dict(steps=train_steps,
                              early_stopped=stopped,
                              finetune=False)
                try:
                    status['earlier'] = yaml.load(
                        self._trained_flag.read_text(),
                        Loader=yaml.SafeLoader)
                except Exception as _:
                    pass
                # Path has no write() method; serialize to text instead
                self._trained_flag.write_text(yaml.dump(status))
        if finetune_steps:  # Fine tuning
            log.info(f"Fine tuning up to {finetune_steps}, "
                     f"batch_size={finetune_batch_size}")
            assert finetune_batch_size
            run_args['steps'] = finetune_steps
            run_args['batch_size'] = finetune_batch_size

            stopped = trainer.train(fine_tune=True, **run_args)
            status = dict(steps=finetune_steps,
                          early_stopped=stopped,
                          finetune=True)
            try:
                status['earlier'] = yaml.load(self._trained_flag.read_text(),
                                              Loader=yaml.SafeLoader)
            except Exception as _:
                pass
            self._trained_flag.write_text(yaml.dump(status))
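Note: each completed phase rewrites the flag file and nests the previous contents under `earlier`, so the file accumulates a history of training runs. A minimal sketch of that chaining with standard PyYAML (the flag path is hypothetical):

import yaml
from pathlib import Path

flag = Path('_TRAINED')  # hypothetical flag path
flag.write_text(yaml.dump(dict(steps=100_000, finetune=False)))
status = dict(steps=120_000, finetune=True)
status['earlier'] = yaml.load(flag.read_text(), Loader=yaml.SafeLoader)
flag.write_text(yaml.dump(status))
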
Example #6
File: exp.py  Project: isi-nlp/rtg
def load_conf(inp: Union[str, Path]):
    with IO.reader(inp) as fh:
        return yaml.load(fh, Loader=yaml.SafeLoader)
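Typical usage is just loading an experiment's config file; the path below is hypothetical, and the result is a dict assuming the YAML holds a top-level mapping:

conf = load_conf('experiment/conf.yml')  # hypothetical path
print(conf.get('model_type'))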