def export(self, target: Path, name: str = None, ensemble: int = 1,
           copy_config=True, copy_vocab=True):
    """Export this experiment's model (averaged over checkpoints) to a new
    experiment directory.

    :param target: destination experiment directory
    :param name: short, whitespace-free tag used in the exported checkpoint
        file name; required when ``ensemble > 0``
    :param ensemble: number of checkpoints (from ``list_models()`` order) to average
    :param copy_config: when True, persist this experiment's config at target
    :param copy_vocab: when True, copy vocabulary files to target
    """
    to_exp = Experiment(target, config=self.exp.config)
    if copy_config:
        log.info("Copying config")
        to_exp.persist_state()
    if copy_vocab:
        log.info("Copying vocabulary")
        self.exp.copy_vocabs(to_exp)
    if ensemble > 0:
        # `name` becomes part of the exported file name, so it must be a single token
        assert name
        assert len(name.split()) == 1
        log.info("Going to average models and then copy")
        model_paths = self.exp.list_models()[:ensemble]
        log.info(f'Model paths: {model_paths}')
        checkpts = [torch.load(mp) for mp in model_paths]
        states = [chkpt['model_state'] for chkpt in checkpts]
        log.info("Averaging them ...")
        avg_state = Decoder.average_states(*states)
        chkpt_state = dict(model_state=avg_state,
                           model_type=checkpts[0]['model_type'],
                           model_args=checkpts[0]['model_args'])
        log.info("Instantiating it ...")
        model = instantiate_model(checkpt_state=chkpt_state, exp=self.exp)
        log.info(f"Exporting to {target}")
        # BUG FIX: the original re-created `to_exp = Experiment(target, ...)` here and
        # called persist_state() unconditionally, which defeated copy_config=False.
        # Reuse the instance created above instead.
        src_chkpt = checkpts[0]
        log.warning(
            "step number, training loss and validation loss are not recalculated.")
        step_num, train_loss, val_loss = [
            src_chkpt.get(n, -1) for n in ['step', 'train_loss', 'val_loss']
        ]
        # Carry over bookkeeping fields from the first (most recent) checkpoint.
        copy_fields = ['optim_state', 'step', 'train_loss', 'valid_loss', 'time',
                       'rtg_version', 'model_type', 'model_args']
        state = {c: src_chkpt[c] for c in copy_fields if c in src_chkpt}
        state['model_state'] = model.state_dict()
        state['averaged_time'] = time.time()
        state['model_paths'] = model_paths
        state['num_checkpts'] = len(model_paths)
        prefix = f'model_{name}_avg{len(model_paths)}'
        to_exp.store_model(step_num, state, train_score=train_loss,
                           val_score=val_loss, keep=10, prefix=prefix)
def inherit_parent(self):
    """Inherit artifacts from a parent experiment, per ``self.config['parent']``.

    Three optional inheritance specs are honored:
      * ``vocab``  -- copy vocabulary/field files from the parent experiment
      * ``model``  -- average the parent's checkpoints into ``self.parent_model_state``
      * ``shrink`` -- shrink vocabularies and remap embedding/generator rows of the
        averaged parent state accordingly
    """
    parent = self.config['parent']
    # Parent is opened read-only; nothing in the parent experiment is modified here.
    parent_exp = TranslationExperiment(parent['experiment'], read_only=True)
    log.info(f"Parent experiment: {parent_exp.work_dir}")
    parent_exp.has_prepared()
    vocab_sepc = parent.get('vocab')
    if vocab_sepc:
        log.info(f"Parent vocabs inheritance spec: {vocab_sepc}")
        # Inherit the codec library setting too, since vocab files depend on it.
        codec_lib = parent_exp.config['prep'].get('codec_lib')
        if codec_lib:
            self.config['prep']['codec_lib'] = codec_lib

        def _locate_field_file(exp: TranslationExperiment, name,
                               check_exists=False) -> Path:
            # Map a logical field name ('src'/'tgt'/'shared') to its vocab file path.
            switch = {'src': exp._src_field_file, 'tgt': exp._tgt_field_file,
                      'shared': exp._shared_field_file}
            assert name in switch, f'{name} not allowed; valid options= {switch.keys()}'
            file = switch[name]
            if check_exists:
                assert file.exists(
                ), f'{file} doesnot exist; for {name} of {exp.work_dir}'
            return file

        # Spec maps destination field (this exp) -> source field (parent exp).
        for to_field, from_field in vocab_sepc.items():
            from_field_file = _locate_field_file(parent_exp, from_field,
                                                 check_exists=True)
            to_field_file = _locate_field_file(self, to_field, check_exists=False)
            IO.copy_file(from_field_file, to_field_file)
        self.reload_vocabs()
    else:
        log.info("No vocabularies are inherited from parent")
    model_sepc = parent.get('model')
    if model_sepc:
        log.info("Parent model inheritance spec")
        if model_sepc.get('args'):
            self.model_args = parent_exp.model_args
        # Average the top `ensemble` checkpoints (by step, newest first).
        ensemble = model_sepc.get('ensemble', 1)
        model_paths = parent_exp.list_models(sort='step', desc=True)[:ensemble]
        log.info(
            f"Averaging {len(model_paths)} checkpoints of parent model: \n{model_paths}")
        from rtg.module.decoder import Decoder
        avg_state = Decoder.average_states(model_paths=model_paths)
        log.info(f"Saving parent model's state to {self.parent_model_state}")
        torch.save(avg_state, self.parent_model_state)
    shrink_spec = parent.get('shrink')
    if shrink_spec:
        # NOTE(review): `avg_state` is only defined when the `model` spec above was
        # given; a `shrink` spec without a `model` spec would raise NameError here
        # — confirm whether configs guarantee they appear together.
        remap_src, remap_tgt = self.shrink_vocabs()

        def map_rows(mapping: List[int], source: torch.Tensor, name=''):
            # Build a new tensor whose row `new_idx` is `source[old_idx]`,
            # i.e. select/reorder rows of `source` per `mapping`.
            assert max(mapping) < len(source)
            target = torch.zeros((len(mapping), *source.shape[1:]),
                                 dtype=source.dtype,
                                 device=source.device)
            for new_idx, old_idx in enumerate(mapping):
                target[new_idx] = source[old_idx]
            log.info(f"Mapped {name} {source.shape} --> {target.shape} ")
            return target

        # Parameters whose leading dimension is vocabulary-sized and hence
        # must be remapped after shrinking (per the original author's note):
        """
        src_embed.0.lut.weight [N x d]
        tgt_embed.0.lut.weight [N x d]
        generator.proj.weight [N x d]
        generator.proj.bias [N]
        """
        if remap_src:
            key = 'src_embed.0.lut.weight'
            avg_state[key] = map_rows(remap_src, avg_state[key], name=key)
        if remap_tgt:
            map_keys = ['tgt_embed.0.lut.weight', 'generator.proj.weight',
                        'generator.proj.bias']
            for key in map_keys:
                if key not in avg_state:
                    log.warning(
                        f'{key} not found in avg_state of parent model. Mapping skipped'
                    )
                    continue
                avg_state[key] = map_rows(remap_tgt, avg_state[key], name=key)
        # Keep the pre-shrink state around as a .orig backup before overwriting.
        if self.parent_model_state.exists():
            self.parent_model_state.rename(
                self.parent_model_state.with_suffix('.orig'))
        torch.save(avg_state, self.parent_model_state)
    self.persist_state(
    )  # this will fix src_vocab and tgt_vocab of model_args conf
def export(self, target: Path, name: str = None, ensemble: int = 1,
           copy_config=True, copy_vocab=True):
    """Export this experiment's (optionally ensemble-averaged) model to a new
    experiment directory, along with scores, the rtg.zip snapshot (if present),
    a ``_EXPORTED`` provenance record, and the trained flag.

    :param target: destination experiment directory
    :param name: short, whitespace-free tag embedded in the exported checkpoint name
    :param ensemble: number of checkpoints to average; must be > 0
    :param copy_config: when True, persist this experiment's config at target
    :param copy_vocab: when True, copy vocabulary files to target
    """
    # Validate inputs BEFORE any file-system side effects, so a bad call does not
    # leave a half-exported target directory behind (original asserted after copying).
    assert ensemble > 0
    assert name
    assert len(name.split()) == 1  # name becomes part of a file name; no whitespace

    to_exp = Experiment(target.resolve(), config=self.exp.config)
    if copy_config:
        log.info("Copying config")
        to_exp.persist_state()
    if copy_vocab:
        log.info("Copying vocabulary")
        self.exp.copy_vocabs(to_exp)

    log.info("Going to average models and then copy")
    model_paths = self.exp.list_models()[:ensemble]
    log.info(f'Model paths: {model_paths}')
    chkpt_state = torch.load(model_paths[0], map_location=device)
    if ensemble > 1:
        log.info("Averaging them ...")
        avg_state = Decoder.average_states(model_paths)
        chkpt_state = dict(model_state=avg_state,
                           model_type=chkpt_state['model_type'],
                           model_args=chkpt_state['model_args'])
    log.info("Instantiating it ...")
    model = instantiate_model(checkpt_state=chkpt_state, exp=self.exp)
    log.info(f"Exporting to {target}")
    # BUG FIX: the original re-created `to_exp = Experiment(target, ...)` here
    # (silently dropping the .resolve() applied above) and called persist_state()
    # unconditionally, which defeated copy_config=False. Reuse the instance above.
    scores_file = self.exp.model_dir / 'scores.tsv'
    if scores_file.exists():  # guard like rtg.zip below; scores may not exist yet
        IO.copy_file(scores_file, to_exp.model_dir / 'scores.tsv')
    if (self.exp.work_dir / 'rtg.zip').exists():
        IO.copy_file(self.exp.work_dir / 'rtg.zip', to_exp.work_dir / 'rtg.zip')
    src_chkpt = chkpt_state
    log.warning(
        "step number, training loss and validation loss are not recalculated.")
    step_num, train_loss, val_loss = [
        src_chkpt.get(n, -1) for n in ['step', 'train_loss', 'val_loss']
    ]
    # Carry over bookkeeping fields from the source checkpoint where present.
    copy_fields = ['optim_state', 'step', 'train_loss', 'valid_loss', 'time',
                   'rtg_version', 'model_type', 'model_args']
    state = {c: src_chkpt[c] for c in copy_fields if c in src_chkpt}
    state['model_state'] = model.state_dict()
    state['averaged_time'] = time.time()
    state['model_paths'] = model_paths
    state['num_checkpts'] = len(model_paths)
    prefix = f'model_{name}_avg{len(model_paths)}'
    to_exp.store_model(step_num, state, train_score=train_loss,
                       val_score=val_loss, keep=10, prefix=prefix)
    # Record provenance of this export next to the exported model.
    chkpts = [mp.name for mp in model_paths]
    status = {
        'parent': str(self.exp.work_dir),
        'ensemble': ensemble,
        'checkpts': chkpts,
        'when': datetime.datetime.now().isoformat(),
        'who': os.environ.get('USER', '<unknown>'),
    }
    # NOTE(review): passing a Path as `stream` assumes `yaml` here is a wrapper
    # (e.g. ruamel-based) that accepts paths — plain PyYAML's dump would not; confirm.
    yaml.dump(status, stream=to_exp.work_dir / '_EXPORTED')
    if self.exp._trained_flag.exists():
        IO.copy_file(self.exp._trained_flag, to_exp._trained_flag)