# Standard-library / third-party imports used by the functions below; project-internal names
# (log, yaml, device, IO, Combo, SmoothKLD, TranslationExperiment, Experiment, ...) are
# expected to be provided by the surrounding rtg codebase.
import copy
from pathlib import Path
from typing import List, Union

import torch


def __init__(self, models: List[Path], exp: Union[Path, TranslationExperiment],
             lr: float = 1e-4, smoothing=0.1):
    """Set up system-combination training: load the component models, restore previously
    tuned combination weights if a weights file exists, and build the optimizer and the
    label-smoothed KL-divergence criterion."""
    if isinstance(exp, Path):
        exp = TranslationExperiment(exp)
    self.w_file = exp.work_dir / 'combo-weights.yml'
    wt = None
    if self.w_file.exists():
        with IO.reader(self.w_file) as rdr:
            combo_spec = yaml.load(rdr)
        weights = combo_spec['weights']
        assert len(weights) == len(models)  # same models as before: no messing allowed
        model_path_strs = [str(m) for m in models]
        for m in model_path_strs:
            assert m in weights, f'{m} not found in weights file.'
        wt = [weights[str(m)] for m in model_path_strs]
        log.info(f"restoring previously stored weights {wt}")
    from rtg.module.decoder import load_models
    combo = Combo(load_models(models, exp), model_paths=models, w=wt)
    self.combo = combo.to(device)
    self.exp = exp
    self.optim = torch.optim.Adam(combo.parameters(), lr=lr)
    self.criterion = SmoothKLD(vocab_size=combo.vocab_size, padding_idx=exp.tgt_vocab.pad_idx,
                               smoothing=smoothing)
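# A minimal sketch of how the 'combo-weights.yml' file read by __init__ above could be written.
# Assumptions (not part of the original code): the helper name write_combo_weights and the example
# values are hypothetical; only the top-level 'weights' mapping of str(model_path) -> float is
# implied by __init__ here and by validate_args below, which also expects the weights to sum to 1.0.
from typing import Dict


def write_combo_weights(w_file: Path, model_weights: Dict[Path, float]):
    assert abs(sum(model_weights.values()) - 1) < 1e-3, 'combination weights should sum to 1.0'
    combo_spec = {'weights': {str(p): float(w) for p, w in model_weights.items()}}
    with w_file.open('w', encoding='utf-8') as out:
        yaml.dump(combo_spec, stream=out)   # same yaml object and dump style as train() below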
def start_step(self):
    _, step = self.exp.get_last_saved_model()
    if self.exp._trained_flag.exists():
        # noinspection PyBroadException
        try:
            step = max(step, yaml.load(self.exp._trained_flag.read_text())['steps'])
        except Exception as _:
            pass
    assert step >= 0
    return step
def validate_args(args, exp: Experiment):
    """Validate decode-time arguments against the experiment state, and switch to
    system-combination mode when a combo weights file is given or found."""
    if not args.pop('skip_check'):  # if --skip-check is not requested
        assert exp.has_prepared(), \
            f'Experiment dir {exp.work_dir} is not ready to train. Please run "prep" sub task'
        assert exp.has_trained(), \
            f'Experiment dir {exp.work_dir} is not ready to decode.' \
            f' Please run "train" sub task or --skip-check to ignore this'

    weights_file = exp.work_dir / 'combo-weights.yml'
    if not args.get('sys_comb') and weights_file.exists():
        log.warning("Found default combo weights, switching to combo mode")
        args['sys_comb'] = weights_file

    if args.get("sys_comb"):
        with IO.reader(args['sys_comb']) as fh:
            weights = yaml.load(fh)['weights']
        args['model_path'], args['weights'] = zip(*weights.items())
        for model in args['model_path']:
            assert Path(model).exists(), model
        assert abs(sum(args['weights']) - 1) < 1e-3, \
            f'Weights from --sys-comb file should sum to 1.0, given={args["weights"]}'
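# A small self-contained illustration (hypothetical data) of the unpacking done in the sys-comb
# branch above: zip(*weights.items()) splits the weights mapping into a tuple of model paths and a
# parallel tuple of weights, which validate_args stores back into args.
def _demo_split_sys_comb_weights():
    weights = {'runs/a/model.pkl': 0.6, 'runs/b/model.pkl': 0.4}  # hypothetical entries
    model_paths, wts = zip(*weights.items())
    assert model_paths == ('runs/a/model.pkl', 'runs/b/model.pkl')
    assert abs(sum(wts) - 1) < 1e-3  # same tolerance check as validate_args
    return model_paths, wts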
def __init__(self, path: Union[str, Path]):
    """Load a saved vocabulary mapping and attach the matching pretrained hub model's BPE codec;
    builds token<->index lookups and a new-index -> old-index remapping table."""
    with IO.reader(path) as rdr:
        data = yaml.load(rdr)
    hub_api = self.load_hub_model(data['model_id'])
    # these are for XLM-R, i.e. RoBERTa from fairseq; generalize it for other models later
    self.bpe = hub_api.bpe
    self.tok2idx = {tok: new_idx for tok, (new_idx, old_idx) in data['mapping'].items()}
    self.idx2tok = list(sorted(self.tok2idx.keys(), key=self.tok2idx.get, reverse=False))
    assert len(self.idx2tok) == len(self.tok2idx)
    for tok, idx in self.reserved():  # reserved tokens must keep their reserved positions
        assert self.tok2idx[tok] == idx
        assert self.idx2tok[idx] == tok
    self.new_idx2old_idx = {new_idx: old_idx for tok, (new_idx, old_idx) in data['mapping'].items()}
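# A minimal sketch (hypothetical values) of the YAML payload that the vocabulary-mapping __init__
# above expects: a hub 'model_id' plus a 'mapping' of token -> [new_idx, old_idx], where new_idx is
# the index in the compacted vocabulary and old_idx is the index in the original pretrained
# (e.g. XLM-R) vocabulary. The model id, tokens, and indices shown here are illustrative only.
def example_mapping_spec() -> dict:
    return {
        'model_id': 'xlmr.base',        # hypothetical hub model id
        'mapping': {
            '<pad>': [0, 1],            # token: [new_idx, old_idx]
            '<s>': [1, 0],
            '▁the': [2, 5],
        },
    }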
def train(self, args=None):
    """Run (or resume) training as configured under 'trainer' in the experiment config, then
    optionally fine-tune for additional steps; records completion status in the trained-flag
    file and skips work that was already done."""
    run_args = copy.deepcopy(self.config.get('trainer', {}))
    if args:
        run_args.update(args)
    if 'init_args' in run_args:
        del run_args['init_args']
    train_steps = run_args['steps']
    finetune_steps = run_args.pop('finetune_steps', None)
    finetune_batch_size = run_args.pop('finetune_batch_size', run_args.get('batch_size'))
    if finetune_steps:
        assert type(finetune_steps) is int
        assert finetune_steps > train_steps, f'finetune_steps={finetune_steps} should be' \
                                             f' greater than steps={train_steps}'

    _, last_step = self.get_last_saved_model()
    if self._trained_flag.exists():
        # noinspection PyBroadException
        try:
            last_step = max(last_step, yaml.load(self._trained_flag.read_text())['steps'])
        except Exception as _:
            pass

    if last_step >= train_steps and (finetune_steps is None or last_step >= finetune_steps):
        log.warning(f"Already trained upto {last_step};"
                    f" Requested: train={train_steps}, finetune={finetune_steps} Skipped")
        return
    from rtg.registry import trainers, factories
    name, optim_args = self.optim_args
    trainer = trainers[self.model_type](self, optim=name,
                                        model_factory=factories[self.model_type], **optim_args)
    if last_step < train_steps:  # regular training
        stopped = trainer.train(fine_tune=False, **run_args)
        if not self.read_only:
            status = dict(steps=train_steps, early_stopped=stopped, finetune=False)
            try:
                status['earlier'] = yaml.load(self._trained_flag.read_text())
            except Exception as _:
                pass
            yaml.dump(status, stream=self._trained_flag)
    if finetune_steps:  # fine tuning
        log.info(f"Fine tuning upto {finetune_steps}, batch_size={finetune_batch_size}")
        assert finetune_batch_size
        run_args['steps'] = finetune_steps
        run_args['batch_size'] = finetune_batch_size
        stopped = trainer.train(fine_tune=True, **run_args)
        status = dict(steps=finetune_steps, early_stopped=stopped, finetune=True)
        try:
            status['earlier'] = yaml.load(self._trained_flag.read_text())
        except Exception as _:
            pass
        yaml.dump(status, stream=self._trained_flag)
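# Sketch (hypothetical numbers) of the status record that train() above writes to the
# trained-flag file after each phase; 'earlier' nests the previous record, so the file
# keeps a short history of completed training and fine-tuning runs.
def example_trained_flag_status() -> dict:
    return dict(
        steps=200_000,           # target steps of the phase that just finished
        early_stopped=False,     # whether the trainer stopped before reaching 'steps'
        finetune=True,           # True for the fine-tuning phase, False for regular training
        earlier=dict(steps=100_000, early_stopped=False, finetune=False),  # previous record
    )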
def load_conf(inp: Union[str, Path]):
    with IO.reader(inp) as fh:
        return yaml.load(fh)
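# Example usage of load_conf (the path shown is hypothetical); train() above reads its run
# arguments from the 'trainer' section of the mapping returned here.
def example_load_trainer_conf():
    conf = load_conf('experiments/my-run/conf.yml')  # hypothetical experiment config file
    return conf.get('trainer', {})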