def write(cls, path, records: Iterator[ParallelSeqRecord]):
    if path.exists():
        log.warning(f"Overwriting {path} with new records")
        os.remove(str(path))
    maybe_tmp = IO.maybe_tmpfs(path)
    log.info(f'Creating {maybe_tmp}')
    conn = sqlite3.connect(str(maybe_tmp))
    cur = conn.cursor()
    cur.execute(cls.TABLE_STATEMENT)
    cur.execute(cls.INDEX_X_LEN)
    cur.execute(cls.INDEX_Y_LEN)
    cur.execute(f"PRAGMA user_version = {cls.CUR_VERSION};")

    count = 0
    for x_seq, y_seq in records:
        # use numpy arrays; they are a lot more efficient than Python lists
        if not isinstance(x_seq, np.ndarray):
            x_seq = np.array(x_seq, dtype=np.int32)
        if y_seq is not None and not isinstance(y_seq, np.ndarray):
            y_seq = np.array(y_seq, dtype=np.int32)
        values = (x_seq.tobytes(),
                  None if y_seq is None else y_seq.tobytes(),
                  len(x_seq),
                  len(y_seq) if y_seq is not None else -1)
        cur.execute(cls.INSERT_STMT, values)
        count += 1
    cur.close()
    conn.commit()
    if maybe_tmp != path:
        # bring the file back to the original location where it should be
        IO.copy_file(maybe_tmp, path)
    log.info(f"stored {count} rows in {path}")
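# Usage sketch (hypothetical caller): write token-id records into the SQLite file.
# Assumes this classmethod lives on a SQLite-backed dataset class, here called
# SqliteFile; the class name and paths are illustrative. Records follow the
# (x_seq, y_seq) shape consumed above, with y_seq=None stored as y_len = -1.
from pathlib import Path
import numpy as np

records = [
    (np.array([2, 15, 32, 3], dtype=np.int32), np.array([2, 44, 3], dtype=np.int32)),
    (np.array([2, 7, 3], dtype=np.int32), None),   # no target side for this row
]
SqliteFile.write(Path('data/train.db'), iter(records))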
def fork_experiment(from_exp: Path, to_exp: Path, conf: bool, vocab: bool, data: bool, code: bool):
    assert from_exp.exists()
    log.info(f'Fork: {str(from_exp)} → {str(to_exp)}')
    if not to_exp.exists():
        log.info(f"Create dir {str(to_exp)}")
        to_exp.mkdir(parents=True)
    if conf:
        conf_file = to_exp / 'conf.yml'
        IO.maybe_backup(conf_file)
        IO.copy_file(from_exp / 'conf.yml', conf_file)
    if data:
        to_data_dir = (to_exp / 'data')
        from_data_dir = from_exp / 'data'
        if to_data_dir.is_symlink():
            log.info(f"removing the existing data link: {to_data_dir.resolve()}")
            to_data_dir.unlink()
        assert not to_data_dir.exists()
        assert from_data_dir.exists()
        log.info(f"link {to_data_dir} → {from_data_dir}")
        to_data_dir.symlink_to(from_data_dir.resolve())
        (to_exp / '_PREPARED').touch(exist_ok=True)
    if not data and vocab:  # just the vocab
        Experiment(from_exp, read_only=True).copy_vocabs(
            Experiment(to_exp, config={'Not': 'Empty'}, read_only=True))
    if code:
        for f in ['rtg.zip', 'githead']:
            src = from_exp / f
            if not src.exists():
                log.warning(f"File Not Found: {src}")
                continue
            IO.copy_file(src, to_exp / f)
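# Usage sketch: fork an existing experiment into a new directory, reusing its
# config, vocabulary and prepared data, but skipping the code snapshot.
# Directory names are illustrative.
from pathlib import Path

fork_experiment(Path('runs/001-base'), Path('runs/002-finetune'),
                conf=True, vocab=True, data=True, code=False)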
def copy_vocabs(self, other):
    """
    Copies vocabulary files from self to other
    :param other: other experiment
    :return:
    """
    other: TranslationExperiment = other
    if not other.data_dir.exists():
        other.data_dir.mkdir(parents=True)
    for source, destination in [(self._src_field_file, other._src_field_file),
                                (self._tgt_field_file, other._tgt_field_file),
                                (self._shared_field_file, other._shared_field_file)]:
        if source.exists():
            IO.copy_file(source, destination)
            src_txt_file = source.with_name(source.name.replace('.model', '.vocab'))
            if src_txt_file.exists():
                dst_txt_file = destination.with_name(destination.name.replace('.model', '.vocab'))
                IO.copy_file(src_txt_file, dst_txt_file)
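# Usage sketch: copy the *.model field files (and their *.vocab text dumps, if
# present) from a source experiment to a freshly forked one. The constructor
# calls mirror the ones in fork_experiment() above; directory names are illustrative.
src_exp = Experiment('runs/001-base', read_only=True)
dst_exp = Experiment('runs/002-finetune', config={'Not': 'Empty'}, read_only=True)
src_exp.copy_vocabs(dst_exp)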
def inherit_parent(self):
    parent = self.config['parent']
    parent_exp = TranslationExperiment(parent['experiment'], read_only=True)
    log.info(f"Parent experiment: {parent_exp.work_dir}")
    parent_exp.has_prepared()
    vocab_spec = parent.get('vocab')
    if vocab_spec:
        log.info(f"Parent vocabs inheritance spec: {vocab_spec}")
        codec_lib = parent_exp.config['prep'].get('codec_lib')
        if codec_lib:
            self.config['prep']['codec_lib'] = codec_lib

        def _locate_field_file(exp: TranslationExperiment, name, check_exists=False) -> Path:
            switch = {'src': exp._src_field_file,
                      'tgt': exp._tgt_field_file,
                      'shared': exp._shared_field_file}
            assert name in switch, f'{name} not allowed; valid options = {switch.keys()}'
            file = switch[name]
            if check_exists:
                assert file.exists(), f'{file} does not exist; for {name} of {exp.work_dir}'
            return file

        for to_field, from_field in vocab_spec.items():
            from_field_file = _locate_field_file(parent_exp, from_field, check_exists=True)
            to_field_file = _locate_field_file(self, to_field, check_exists=False)
            IO.copy_file(from_field_file, to_field_file)
        self.reload_vocabs()
    else:
        log.info("No vocabularies are inherited from parent")

    model_spec = parent.get('model')
    if model_spec:
        log.info("Parent model inheritance spec")
        if model_spec.get('args'):
            self.model_args = parent_exp.model_args
        ensemble = model_spec.get('ensemble', 1)
        model_paths = parent_exp.list_models(sort='step', desc=True)[:ensemble]
        log.info(f"Averaging {len(model_paths)} checkpoints of parent model: \n{model_paths}")
        from rtg.module.decoder import Decoder
        avg_state = Decoder.average_states(model_paths=model_paths)
        log.info(f"Saving parent model's state to {self.parent_model_state}")
        torch.save(avg_state, self.parent_model_state)

    shrink_spec = parent.get('shrink')
    if shrink_spec:
        remap_src, remap_tgt = self.shrink_vocabs()

        def map_rows(mapping: List[int], source: torch.Tensor, name=''):
            assert max(mapping) < len(source)
            target = torch.zeros((len(mapping), *source.shape[1:]),
                                 dtype=source.dtype, device=source.device)
            for new_idx, old_idx in enumerate(mapping):
                target[new_idx] = source[old_idx]
            log.info(f"Mapped {name} {source.shape} --> {target.shape}")
            return target

        """
        src_embed.0.lut.weight [N x d]
        tgt_embed.0.lut.weight [N x d]
        generator.proj.weight  [N x d]
        generator.proj.bias    [N]
        """
        if remap_src:
            key = 'src_embed.0.lut.weight'
            avg_state[key] = map_rows(remap_src, avg_state[key], name=key)
        if remap_tgt:
            map_keys = ['tgt_embed.0.lut.weight', 'generator.proj.weight', 'generator.proj.bias']
            for key in map_keys:
                if key not in avg_state:
                    log.warning(f'{key} not found in avg_state of parent model. Mapping skipped')
                    continue
                avg_state[key] = map_rows(remap_tgt, avg_state[key], name=key)
        if self.parent_model_state.exists():
            self.parent_model_state.rename(self.parent_model_state.with_suffix('.orig'))
        torch.save(avg_state, self.parent_model_state)
        self.persist_state()  # this will fix src_vocab and tgt_vocab of model_args conf
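# Config sketch: the 'parent' block that inherit_parent() reads from self.config,
# shown as an equivalent Python dict. Keys follow the lookups above (experiment,
# vocab, model, shrink); the concrete values are illustrative.
parent_conf = {
    'experiment': 'runs/001-base',                 # parent experiment directory
    'vocab': {'src': 'shared', 'tgt': 'shared'},   # to_field <- from_field mapping
    'model': {'args': True, 'ensemble': 5},        # average the last 5 parent checkpoints
    'shrink': True,                                # shrink vocabs and remap embedding rows
}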
def export(self, target: Path, name: str = None, ensemble: int = 1,
           copy_config=True, copy_vocab=True):
    to_exp = Experiment(target.resolve(), config=self.exp.config)
    if copy_config:
        log.info("Copying config")
        to_exp.persist_state()
    if copy_vocab:
        log.info("Copying vocabulary")
        self.exp.copy_vocabs(to_exp)

    assert ensemble > 0
    assert name
    assert len(name.split()) == 1
    log.info("Going to average models and then copy")
    model_paths = self.exp.list_models()[:ensemble]
    log.info(f'Model paths: {model_paths}')
    chkpt_state = torch.load(model_paths[0], map_location=device)
    if ensemble > 1:
        log.info("Averaging them ...")
        avg_state = Decoder.average_states(model_paths)
        chkpt_state = dict(model_state=avg_state,
                           model_type=chkpt_state['model_type'],
                           model_args=chkpt_state['model_args'])
    log.info("Instantiating it ...")
    model = instantiate_model(checkpt_state=chkpt_state, exp=self.exp)
    log.info(f"Exporting to {target}")
    to_exp = Experiment(target, config=self.exp.config)
    to_exp.persist_state()

    IO.copy_file(self.exp.model_dir / 'scores.tsv', to_exp.model_dir / 'scores.tsv')
    if (self.exp.work_dir / 'rtg.zip').exists():
        IO.copy_file(self.exp.work_dir / 'rtg.zip', to_exp.work_dir / 'rtg.zip')

    src_chkpt = chkpt_state
    log.warning("step number, training loss and validation loss are not recalculated.")
    step_num, train_loss, val_loss = [src_chkpt.get(n, -1)
                                      for n in ['step', 'train_loss', 'val_loss']]
    copy_fields = ['optim_state', 'step', 'train_loss', 'valid_loss', 'time',
                   'rtg_version', 'model_type', 'model_args']
    state = dict((c, src_chkpt[c]) for c in copy_fields if c in src_chkpt)
    state['model_state'] = model.state_dict()
    state['averaged_time'] = time.time()
    state['model_paths'] = model_paths
    state['num_checkpts'] = len(model_paths)
    prefix = f'model_{name}_avg{len(model_paths)}'
    to_exp.store_model(step_num, state, train_score=train_loss, val_score=val_loss,
                       keep=10, prefix=prefix)
    chkpts = [mp.name for mp in model_paths]
    status = {
        'parent': str(self.exp.work_dir),
        'ensemble': ensemble,
        'checkpts': chkpts,
        'when': datetime.datetime.now().isoformat(),
        'who': os.environ.get('USER', '<unknown>'),
    }
    yaml.dump(status, stream=to_exp.work_dir / '_EXPORTED')
    if self.exp._trained_flag.exists():
        IO.copy_file(self.exp._trained_flag, to_exp._trained_flag)
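# Usage sketch: export the averaged top-5 checkpoints into a standalone experiment
# directory. 'exporter' stands for whatever object owns this export() method and
# wraps the source experiment as self.exp; the variable name and paths are illustrative.
from pathlib import Path

exporter.export(Path('exported/ende-avg5'), name='ende', ensemble=5,
                copy_config=True, copy_vocab=True)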