def write(cls, path, records: Iterator[ParallelSeqRecord]):
    """Store parallel sequence records in a fresh SQLite database at ``path``.

    Any existing file at ``path`` is deleted first. The database is built at the location
    returned by ``IO.maybe_tmpfs(path)`` and copied to ``path`` when that location differs.

    :param path: destination database file (``.exists()`` is called on it, e.g. a pathlib.Path)
    :param records: iterable of ``(x_seq, y_seq)`` pairs of integer ids; ``y_seq`` may be None
    """
    if path.exists():
        log.warning(f"Overwriting {path} with new records")
        os.remove(str(path))
    maybe_tmp = IO.maybe_tmpfs(path)
    log.info(f'Creating {maybe_tmp}')
    conn = sqlite3.connect(str(maybe_tmp))
    count = 0
    try:
        cur = conn.cursor()
        cur.execute(cls.TABLE_STATEMENT)
        cur.execute(cls.INDEX_X_LEN)
        cur.execute(cls.INDEX_Y_LEN)
        cur.execute(f"PRAGMA user_version = {cls.CUR_VERSION};")
        for x_seq, y_seq in records:
            # use numpy: int32 arrays give a compact byte blob via tobytes()
            if not isinstance(x_seq, np.ndarray):
                x_seq = np.array(x_seq, dtype=np.int32)
            if y_seq is not None and not isinstance(y_seq, np.ndarray):
                y_seq = np.array(y_seq, dtype=np.int32)
            values = (x_seq.tobytes(),
                      None if y_seq is None else y_seq.tobytes(),
                      len(x_seq),
                      len(y_seq) if y_seq is not None else -1)  # -1 stands for "no y_seq"
            cur.execute(cls.INSERT_STMT, values)
            count += 1
        cur.close()
        conn.commit()
    finally:
        # always release the database handle, also when a record fails to insert
        conn.close()
    if maybe_tmp != path:
        # bring the file back to original location where it should be
        IO.copy_file(maybe_tmp, path)
    log.info(f"stored {count} rows in {path}")
def fork_experiment(from_exp: Path, to_exp: Path, conf: bool, vocab: bool, data: bool):
    """Create or update the experiment dir ``to_exp`` from the existing experiment ``from_exp``.

    :param from_exp: source experiment directory; must exist
    :param to_exp: target experiment directory; created when missing
    :param conf: copy ``conf.yml`` (an existing target conf goes through ``IO.maybe_backup`` first)
    :param vocab: copy the vocabulary files; only honoured when ``data`` is False
    :param data: replace the target ``data`` dir with a symlink to the source ``data`` dir and
                 touch the target's ``_PREPARED`` flag file
    """
    assert from_exp.exists()
    log.info(f'Fork: {str(from_exp)} → {str(to_exp)}')
    if not to_exp.exists():
        log.info(f"Create dir {str(to_exp)}")
        to_exp.mkdir(parents=True)

    if conf:
        target_conf = to_exp / 'conf.yml'
        IO.maybe_backup(target_conf)
        IO.copy_file(from_exp / 'conf.yml', target_conf)

    if data:
        dst_data = to_exp / 'data'
        src_data = from_exp / 'data'
        if dst_data.is_symlink():
            log.info(f"removing the existing data link: {dst_data.resolve()}")
            dst_data.unlink()
        assert not dst_data.exists()
        assert src_data.exists()
        log.info(f"link {dst_data} → {src_data}")
        dst_data.symlink_to(src_data.resolve())
        (to_exp / '_PREPARED').touch(exist_ok=True)
    elif vocab:
        # just the vocab
        source_exp = Experiment(from_exp, read_only=True)
        source_exp.copy_vocabs(Experiment(to_exp, config={'Not': 'Empty'}, read_only=True))
def read_raw_parallel_lines(src_path: Union[str, Path], tgt_path: Union[str, Path]) \
        -> Iterator[RawRecord]:
    """Lazily yield stripped ``(src, tgt)`` line pairs from two line-aligned text files.

    Pairs where either side is empty after stripping are skipped.

    :param src_path: source-side text file
    :param tgt_path: target-side text file
    :raises ValueError: when the two files have a different number of lines
    """
    with IO.reader(src_path) as src_lines, IO.reader(tgt_path) as tgt_lines:
        for line_num, (src, tgt) in enumerate(zip_longest(src_lines, tgt_lines), start=1):
            if src is None or tgt is None:
                # zip_longest pads the shorter file with None
                raise ValueError(f'{src_path} and {tgt_path} have unequal number of lines'
                                 f' (one of them ends after {line_num - 1} lines)')
            src, tgt = src.strip(), tgt.strip()
            if src and tgt:
                yield src, tgt
def write_parallel(data, src_file, tgt_file):
    """Write parallel sequence pairs to two line-aligned text files.

    Each sequence is written on its own line as its items converted with ``str`` and joined
    by single spaces.

    :param data: iterable of ``(src_seq, tgt_seq)`` pairs
    :param src_file: output path for the source side
    :param tgt_file: output path for the target side
    """
    n_written = 0
    with IO.writer(src_file) as src_out, IO.writer(tgt_file) as tgt_out:
        for pair_src, pair_tgt in data:
            src_line = ' '.join(str(tok) for tok in pair_src)
            tgt_line = ' '.join(str(tok) for tok in pair_tgt)
            src_out.write(src_line + '\n')
            tgt_out.write(tgt_line + '\n')
            n_written += 1
    log.info(f"Wrote {n_written} records to {src_file} and {tgt_file}")
def moses_detokenize(self, inp: Path, out: Path, col=0, lang='en', post_op=None):
    """Detokenize whitespace-tokenized lines of ``inp`` with Moses and write them to ``out``.

    :param inp: input file, read through ``IO.get_lines`` at column ``col``
    :param out: output file
    :param col: column of ``inp`` to read
    :param lang: language code given to ``MosesDetokenizer``
    :param post_op: optional callable applied to every detokenized line
    """
    log.info(f"detok : {inp} --> {out}")
    tokenized = IO.get_lines(inp, col=col, line_mapper=lambda x: x.split())
    with MosesDetokenizer(lang=lang) as detokenizer:
        if post_op:
            results = (post_op(detokenizer(toks)) for toks in tokenized)
        else:
            results = (detokenizer(toks) for toks in tokenized)
        IO.write_lines(out, results)
def train(cls, model_type: str, vocab_size: int, model_path: Union[Path, str], files: List[str],
          tok_coverage=0.9999, **kwargs):
    """Select a sub-vocabulary of a hub model's BPE types and store the index mapping.

    Every line of ``files`` is encoded with the hub model's ``bpe``. Types are ranked by
    frequency and kept while the cumulative token coverage is at most ``tok_coverage``.
    Reserved types (``cls.reserved()``) take the first indices, followed by the kept types.
    A YAML file with ``model_id`` and ``mapping`` (type -> [new index, old index], where the old
    index is -1 when the type is absent from the hub dictionary) is written to ``model_path``.

    :param model_type: hub model id, passed to ``cls.load_hub_model``
    :param vocab_size: not used here; the size follows from ``tok_coverage``
    :param model_path: where to write the mapping file
    :param files: text files to count BPE types from
    :param tok_coverage: fraction of all tokens the kept types should cover
    :param kwargs: ignored
    :return: ``cls(model_path)``
    :raises ValueError: when ``files`` contain no tokens
    """
    hub_api = cls.load_hub_model(model_type)
    bpe = hub_api.bpe
    dicto = hub_api.task.dictionary
    freqs = coll.Counter()
    lines = IO.get_liness(*files)
    for line in tqdm(lines, mininterval=2, dynamic_ncols=True, unit='line'):
        freqs.update(bpe.encode(line).split())
    total_toks = sum(freqs.values())
    log.info(f"Found {len(freqs)} bpe types and {total_toks} toks")
    if total_toks == 0:
        # coverage is a ratio over total_toks; nothing to select from
        raise ValueError(f'No tokens found in {files}; cannot build a vocabulary')

    vocabulary, oovs = [], []
    cumulative = 0
    # most_common(): sorted by count, descending; ties keep their first-seen order
    for t, f in freqs.most_common():
        if cumulative / total_toks <= tok_coverage:
            vocabulary.append((t, f))
            cumulative += f
        else:
            oovs.append((t, f))
    oovs_str = ' '.join(f'{t}:{f}' for t, f in oovs)
    log.info(f'Excluded {len(oovs)} types as OOVs.\n:{oovs_str}')
    log.info(f'Included {len(vocabulary)} types as in vocabulary; '
             f'Coverage = {cumulative / total_toks:g}')

    # TODO: mapping should be list[int] with one on one map
    types, indices = [], {}
    for typ, new_idx in cls.reserved():
        assert len(types) == new_idx
        types.append(typ)
        old_idx = dicto.indices.get(typ, -1)
        indices[typ] = [new_idx, old_idx]
    for typ, freq in vocabulary:
        # [new index, old index]
        indices[typ] = [len(types), dicto.indices.get(typ, -1)]
        types.append(typ)
    data = {'model_id': model_type, 'mapping': indices}
    with IO.writer(model_path) as wrtr:
        yaml.dump(data, wrtr)
    return cls(model_path)
def train(cls, model_type: str, vocab_size: int, model_path: str, files: List[str],
          no_split_toks: Optional[List[str]] = None, char_coverage: float = 0, dedup=True,
          spark=None):
    """
    Learn an nlcodec vocabulary model and load it.

    :param model_type: word, char, bpe
    :param vocab_size: vocabulary size
    :param model_path: where to store vocabulary model
    :param files: text for creating vocabulary
    :param no_split_toks: must be empty; not supported by nlcodec yet
    :param char_coverage: character coverage (0, 1]. value <= 0 => default coverage
    :param dedup: passed to ``term_freq.word_counts``; only used when ``spark`` is given
    :param spark: when given, term frequencies are extracted with it, cached in
        ``model_path + '.termfreqs'`` (reused if already present) and fed to ``learn_vocab``;
        otherwise the raw lines of ``files`` are used
    :return: ``cls(model_path)``
    """
    assert not no_split_toks, 'not supported in nlcodec yet'
    from nlcodec import learn_vocab, term_freq
    extra = {'char_coverage': char_coverage} if char_coverage > 0 else {}
    if spark:
        # extract and store frequencies to this file
        stats_file = model_path + '.termfreqs'
        if not Path(stats_file).exists():
            log.info("Extracting term frequencies... ")
            paths = [Path(f) for f in files]
            wfs, chfs, n_lines = term_freq.word_counts(paths=paths, dedup=dedup, spark=spark)
            log.info(
                f"Lines = {n_lines:,}, Word Types: {len(wfs):,} Char Types:{len(chfs):,}"
            )
            stats = chfs if model_type == 'char' else wfs
            log.info(f"Writing frequencies to {stats_file}")
            with IO.writer(stats_file) as out:
                term_freq.write_stats(stats=stats, out=out, line_count=n_lines)
        extra['term_freqs'] = True
        inp = IO.get_lines(stats_file, delim='\n')
    else:
        inp = IO.get_liness(*files)
    learn_vocab(inp=inp, level=model_type, model=model_path, vocab_size=vocab_size, **extra)
    return cls(model_path)
def __init__(self, models: List[Path], exp: Union[Path, TranslationExperiment], lr: float = 1e-4,
             smoothing=0.1):
    """Set up training of system-combination weights over ``models``.

    When ``<work_dir>/combo-weights.yml`` exists, its stored weights are restored; it must list
    exactly the given models.

    :param models: paths of the models to combine
    :param exp: the experiment, or the path of its work dir
    :param lr: Adam learning rate
    :param smoothing: label smoothing of the training criterion
    """
    if isinstance(exp, Path):
        exp = TranslationExperiment(exp)
    self.w_file = exp.work_dir / 'combo-weights.yml'
    restored = None
    if self.w_file.exists():
        with IO.reader(self.w_file) as rdr:
            # NOTE(review): yaml.load on a file; prefer a safe loader if this yaml module has one
            combo_spec = yaml.load(rdr)
        stored = combo_spec['weights']
        # same models as before: no messing allowed
        assert len(stored) == len(models)
        keys = [str(m) for m in models]
        for key in keys:
            assert key in stored, f'{key} not found in weights file.'
        restored = [stored[key] for key in keys]
        log.info(f"restoring previously stored weights {restored}")

    from rtg.module.decoder import load_models
    combo = Combo(load_models(models, exp), model_paths=models, w=restored)
    self.combo = combo.to(device)
    self.exp = exp
    self.optim = torch.optim.Adam(combo.parameters(), lr=lr)
    self.criterion = LabelSmoothing(vocab_size=combo.vocab_size, padding_idx=PAD_TOK_IDX,
                                    smoothing=smoothing)
def shell_pipe(cls, cmd_line, inp, out):
    """Run a shell command line with ``inp`` as its stdin and ``out`` as its stdout.

    :param cmd_line: shell command line; it runs with ``shell=True``, so it should only come
                     from a trusted source
    :param inp: input file, to read records
    :param out: output file to store records
    :raises subprocess.CalledProcessError: when the command exits with a non-zero status
    """
    log.info(f"Shell cmd:: {cmd_line}")
    with IO.reader(inp) as rdr, IO.writer(out) as wtr:
        proc = subprocess.Popen(cmd_line, stdin=rdr, stdout=wtr, shell=True)
        return_code = proc.wait()
    if return_code != 0:
        # a failed command leaves a partial/empty output file; don't let it pass silently
        raise subprocess.CalledProcessError(return_code, cmd_line)
    log.info("Shell cmd:: Done")
def _read_vocab(path: Path) -> List[str]:
    """Return the first whitespace-separated field of every non-blank line of ``path``.

    Blank lines (e.g. a trailing empty line) are skipped instead of raising IndexError.
    When ``do_clean`` (a name from the enclosing scope) is truthy, one leading '▁' is removed
    from each word.
    """
    with IO.reader(path) as rdr:
        vocab = [fields[0] for fields in (line.split() for line in rdr) if fields]
    if do_clean:
        # sentence piece starts with '▁' character
        vocab = [word[1:] if word.startswith('▁') else word for word in vocab]
    return vocab
def read_raw_mono_recs(path: Union[str, Path], truncate: bool, max_len: int, tokenizer):
    """Lazily yield tokenized records from the non-blank lines of a monolingual text file.

    :param path: text file to read
    :param truncate: when True, each record is cut to ``max_len`` items; otherwise records
                     that are empty or longer than ``max_len`` are dropped
    :param max_len: maximum record length
    :param tokenizer: callable applied to each stripped line
    """
    with IO.reader(path) as lines:
        for line in lines:
            text = line.strip()
            if not text:
                continue
            rec = tokenizer(text)
            if truncate:
                yield rec[:max_len]
            elif 0 < len(rec) <= max_len:
                # Filter out longer sentences
                yield rec
def evaluate_file(self, detok_hyp: Path, ref: Union[Path, List[str]], lowercase=True) -> float:
    """Compute corpus BLEU of a detokenized hypothesis file against a single reference.

    The formatted score is logged and written to ``<detok_hyp name>.lc.sacrebleu``
    (lowercased) or ``<detok_hyp name>.oc.sacrebleu`` in the same directory.

    :param detok_hyp: detokenized hypothesis file
    :param ref: reference file, or the reference lines
    :param lowercase: passed to ``corpus_bleu``
    :return: the BLEU score
    """
    hyp_lines = list(IO.get_lines(detok_hyp))
    # corpus_bleu takes multiple reference streams; only one is used here
    single_ref = IO.get_lines(ref) if isinstance(ref, Path) else ref
    bleu: BLEUScore = corpus_bleu(sys_stream=hyp_lines, ref_streams=[single_ref],
                                  lowercase=lowercase)
    bleu_str = bleu.format()
    case_tag = '.lc' if lowercase else '.oc'
    bleu_file = detok_hyp.with_name(detok_hyp.name + case_tag + '.sacrebleu')
    log.info(f'BLEU {detok_hyp} : {bleu_str}')
    IO.write_lines(bleu_file, bleu_str)
    return bleu.score
def evaluate_file(self, detok_hyp: Path, ref: Union[Path, List[str]], lowercase=True) -> float:
    """Compute corpus BLEU of a detokenized hypothesis file against a single reference.

    A one-line summary is logged and written to ``detok_hyp`` with its last suffix replaced by
    ``.lc.sacrebleu`` (lowercased) or ``.oc.sacrebleu``.

    :param detok_hyp: detokenized hypothesis file
    :param ref: reference file, or the reference lines
    :param lowercase: passed to ``corpus_bleu``
    :return: the BLEU score
    """
    detok_lines = IO.get_lines(detok_hyp)
    # takes multiple refs, but here we have only one
    ref_liness = [IO.get_lines(ref) if isinstance(ref, Path) else ref]
    bleu: BLEU = corpus_bleu(sys_stream=detok_lines, ref_streams=ref_liness, lowercase=lowercase)
    # this should be part of new sacrebleu release (i sent a PR ;)
    # an empty reference has ref_len == 0; avoid ZeroDivisionError after scoring succeeded
    ratio = bleu.sys_len / bleu.ref_len if bleu.ref_len else 0.0
    bleu_str = f'BLEU = {bleu.score:.2f} {"/".join(f"{p:.1f}" for p in bleu.precisions)}' \
               f' (BP = {bleu.bp:.3f} ratio = {ratio:.3f}' \
               f' hyp_len = {bleu.sys_len:d} ref_len={bleu.ref_len:d})'
    bleu_file = detok_hyp.with_suffix(('.lc' if lowercase else '.oc') + '.sacrebleu')
    log.info(f'BLEU {detok_hyp} : {bleu_str}')
    IO.write_lines(bleu_file, bleu_str)
    return bleu.score
def copy_vocabs(self, other):
    """
    Copies vocabulary files from self to other.

    For each of the src, tgt and shared field files present in self, the file is copied to the
    corresponding location in other, together with its ``.vocab`` sibling (``.model`` in the
    file name replaced by ``.vocab``) when that sibling exists.

    :param other: other experiment; its data dir is created when missing
    :return:
    """
    other: TranslationExperiment = other
    if not other.data_dir.exists():
        other.data_dir.mkdir(parents=True)
    pairs = [(self._src_field_file, other._src_field_file),
             (self._tgt_field_file, other._tgt_field_file),
             (self._shared_field_file, other._shared_field_file)]
    for src_file, dst_file in pairs:
        if not src_file.exists():
            continue
        IO.copy_file(src_file, dst_file)
        src_txt = src_file.with_name(src_file.name.replace('.model', '.vocab'))
        if src_txt.exists():
            dst_txt = dst_file.with_name(dst_file.name.replace('.model', '.vocab'))
            IO.copy_file(src_txt, dst_txt)
def store_model(self, epoch: int, model, train_score: float, val_score: float, keep: int,
                prefix='model', keeper_sort='step'):
    """
    saves model to a given path
    :param epoch: epoch number of model
    :param model: model object itself
    :param train_score: score of model on training split
    :param val_score: score of model on validation split
    :param keep: number of good models to keep, bad models will be deleted
    :param prefix: prefix to store model. default is "model"
    :param keeper_sort: criteria for choosing the old or bad models for deletion.
        Choices: {'total_score', 'step'}
    :raises ValueError: when ``keeper_sort`` is not one of the choices (checked before saving)
    :return:
    """
    # TODO: improve this by skipping the model save if the model is not good enough to be saved
    if self.read_only:
        log.warning("Ignoring the store request; experiment is readonly")
        return
    # keeper_sort -> desc flag for list_models; validated up front so that a bad option
    # fails before anything is written or deleted (previously the error was never raised)
    sort_desc = {'total_score': False, 'step': True}
    if keeper_sort not in sort_desc:
        raise ValueError(f'Sort criteria {keeper_sort} not understood')

    name = f'{prefix}_{epoch:03d}_{train_score:.6f}_{val_score:.6f}.pkl'
    path = self.model_dir / name
    log.info(f"Saving epoch {epoch} to {path}")
    torch.save(model, str(path))

    del_models = self.list_models(sort=keeper_sort, desc=sort_desc[keeper_sort])[keep:]
    for d_model in del_models:
        log.info(f"Deleting model {d_model} . Keep={keep}, sort={keeper_sort}")
        os.remove(str(d_model))

    with IO.writer(os.path.join(self.model_dir, 'scores.tsv'), append=True) as f:
        cols = [str(epoch), datetime.now().isoformat(), name,
                f'{train_score:g}', f'{val_score:g}']
        f.write('\t'.join(cols) + '\n')
def decode_eval_file(self, decoder, src: Union[Path, List[str]], out_file: Path,
                     ref: Optional[Union[Path, List[str]]], lowercase: bool = True,
                     **dec_args) -> float:
    """Decode ``src`` into ``out_file`` (unless already done), detokenize, and score against ``ref``.

    Decoding is skipped when ``out_file`` is non-empty and has as many lines as ``src``.

    :param decoder: object with ``decode_file(src_lines, out, **dec_args)``
    :param src: source file or source lines
    :param out_file: where the decoded output is written
    :param ref: reference file or lines; when falsy, no scoring is done
    :param lowercase: passed to ``self.evaluate_file``
    :param dec_args: extra args for ``decoder.decode_file``
    :return: the score from ``self.evaluate_file``, or None when ``ref`` is falsy
    """
    already_decoded = out_file.exists() and out_file.stat().st_size > 0
    if already_decoded:
        n_out = line_count(out_file)
        n_src = len(src) if isinstance(src, list) else line_count(src)
        already_decoded = n_out == n_src

    if already_decoded:
        log.warning(f"{out_file} exists and has desired number of lines. Skipped...")
    else:
        if isinstance(src, Path):
            log.info(f"decoding {src.name}")
            src = list(IO.get_lines(src))
        if isinstance(ref, Path):
            ref = list(IO.get_lines(ref))
        with IO.writer(out_file) as out:
            decoder.decode_file(src, out, **dec_args)

    detok_hyp = self.detokenize(out_file)
    if not ref:
        return None
    return self.evaluate_file(detok_hyp, ref, lowercase=lowercase)
def get_train_data(self, batch_size: Union[int, Tuple[int, int]], steps: int = 0,
                   sort_by='eq_len_rand_batch', batch_first=True, shuffle=False, fine_tune=False,
                   keep_in_mem=False, split_ratio: float = 0., dynamic_epoch=False):
    """Build the iterable of training batches.

    :param batch_size: passed through to the batch iterable
    :param steps: number of batches to produce
    :param sort_by: passed through to the batch iterable
    :param batch_first: passed through to the batch iterable
    :param shuffle: passed through to the batch iterable
    :param fine_tune: use the fine-tuning corpus (prepared on demand) instead of the training one
    :param keep_in_mem: not used by this method
    :param split_ratio: when > 0, batches come from a ``GenerativeBatchIterable`` that builds a
        ``.db.tmp`` file via ``self.file_creator``; otherwise a ``BatchIterable`` over the data
        file is wrapped in a ``LoopingIterable``
    :param dynamic_epoch: only used when ``split_ratio > 0``
    """
    if fine_tune:
        if not self.finetune_file.exists():
            # user may have added fine tune file later
            self._pre_process_parallel('finetune_src', 'finetune_tgt', self.finetune_file)
        log.info("Using Fine tuning corpus instead of training corpus")
        data_path = self.finetune_file
    else:
        data_path = self.train_db if self.train_db.exists() else self.train_file

    if split_ratio > 0:
        data_path = IO.maybe_tmpfs(data_path)
        train_file = data_path.with_suffix('.db.tmp')
        file_creator = partial(self.file_creator, train_file=train_file, split_ratio=split_ratio)
        return GenerativeBatchIterable(file_creator=file_creator, batches=steps,
                                       batch_size=batch_size, field=self.tgt_vocab,
                                       dynamic_epoch=dynamic_epoch, batch_first=batch_first,
                                       shuffle=shuffle, sort_by=sort_by,
                                       **self._get_batch_args())

    batches = BatchIterable(data_path=data_path, batch_size=batch_size, field=self.tgt_vocab,
                            sort_by=sort_by, batch_first=batch_first, shuffle=shuffle,
                            **self._get_batch_args())
    return LoopingIterable(batches, steps)
def read_all(self) -> Iterator[IdExample]:
    """Lazily yield ``IdExample`` records from the tab-separated lines of ``self.path``.

    Column 0 is parsed into x, and column 1 (when present) into y; ``id`` is the 0-based line
    index. With ``self.truncate`` set, x and y are cut to ``self.max_src_len`` and
    ``self.max_tgt_len``; otherwise records exceeding those limits are skipped. Records with an
    empty side are skipped with a warning.
    """
    with IO.reader(self.path) as lines:
        for idx, line in enumerate(lines):
            rec = line.split('\t')
            x = self._parse(rec[0].strip())
            y = self._parse(rec[1].strip()) if len(rec) > 1 else None
            if self.truncate:  # truncate long recs
                x = x[:self.max_src_len]
                y = y if y is None else y[:self.max_tgt_len]
            elif len(x) > self.max_src_len or (0 if y is None else len(y)) > self.max_tgt_len:
                continue  # skip long recs
            if not x or (y is not None and len(y) == 0):  # empty on one side
                # y may be None here (x empty, no second column); len(None) would crash
                y_len = None if y is None else len(y)
                log.warning(f"Ignoring an empty record x:{len(x)} y:{y_len}")
                continue
            yield IdExample(x, y, id=idx)
def save_embeddings(self, step, train_loss, val_loss, txt=True):
    """Log the model's embedding matrix to TensorBoard and save it in the model dir.

    :param step: training step; used as TensorBoard global step and in the file name
    :param train_loss: training loss, used in the file name
    :param val_loss: validation loss, used in the file name
    :param txt: when True, write a ``.txt.gz`` file with a "<rows> <dims>" header line and
                one "<word> <values...>" line per row; otherwise pickle
                ``{'words': [...], 'vectors': <numpy array>}`` to a ``.pkl`` file
    """
    matrix = self.model.emb.weight
    vocab = self.exp.shared_vocab
    words = [vocab.id_to_piece(i) for i in range(len(vocab))]
    self.tbd.add_embedding(matrix, metadata=words, global_step=step)
    ext = 'txt.gz' if txt else 'pkl'
    path = self.exp.model_dir / f'embeddings_{step}_{train_loss:.6f}_{val_loss:.6f}.{ext}'
    log.info(f"writing embedding after step {step} to {path}")
    # detached CPU copy: numpy()/tolist() need no grad tracking and host memory
    vectors = matrix.detach().cpu()
    if txt:
        with IO.writer(path) as w:
            w.write(f'{vectors.shape[0]} {vectors.shape[1]}\n')
            # tolist() gives python floats: same '{:g}' output as per-element tensors, faster
            for i, row in enumerate(vectors.tolist()):
                vect = ' '.join(f'{x:g}' for x in row)
                w.write(f'{words[i]} {vect}\n')
    else:
        with path.open('wb') as f:
            # previously stored the bound method `matrix.numpy` instead of the array
            data = {'words': words, 'vectors': vectors.numpy()}
            pickle.dump(data, f)
def get_combo_data(self, batch_size: int, steps: int = 0, sort_desc=False, batch_first=True,
                   shuffle=False):
    """Batches over the system-combination corpus ``self.combo_file``, prepared on demand.

    :param batch_size: passed through to ``BatchIterable``
    :param steps: when > 0, the batches are wrapped in a ``LoopingIterable`` of this many steps
    :param sort_desc: passed through to ``BatchIterable``
    :param batch_first: passed through to ``BatchIterable``
    :param shuffle: passed through to ``BatchIterable``
    """
    if not self.combo_file.exists():
        # user may have added the combo files later
        self._pre_process_parallel('combo_src', 'combo_tgt', self.combo_file)
    batches = BatchIterable(IO.maybe_tmpfs(self.combo_file), batch_size=batch_size,
                            sort_desc=sort_desc, field=self.tgt_vocab, batch_first=batch_first,
                            shuffle=shuffle, **self._get_batch_args())
    return LoopingIterable(batches, steps) if steps > 0 else batches
def validate_args(args, exp: Experiment):
    """Check that ``exp`` can decode and resolve system-combination settings in ``args``.

    ``args['skip_check']`` is popped; unless it is truthy the experiment must be prepared and
    trained. Without a ``sys_comb`` entry, ``<work_dir>/combo-weights.yml`` is used when it
    exists. With a ``sys_comb`` file in effect, ``args['model_path']`` and ``args['weights']``
    are filled from its ``weights`` mapping; every model path must exist and the weights must
    sum to 1 (within 1e-3).
    """
    skip_check = args.pop('skip_check')
    if not skip_check:  # if --skip-check is not requested
        assert exp.has_prepared(), \
            f'Experiment dir {exp.work_dir} is not ready to train. Please run "prep" sub task'
        assert exp.has_trained(), \
            f'Experiment dir {exp.work_dir} is not ready to decode.' \
            f' Please run "train" sub task or --skip-check to ignore this'

    weights_file = exp.work_dir / 'combo-weights.yml'
    if not args.get('sys_comb') and weights_file.exists():
        log.warning("Found default combo weights, switching to combo mode")
        args['sys_comb'] = weights_file

    sys_comb = args.get("sys_comb")
    if not sys_comb:
        return
    with IO.reader(sys_comb) as fh:
        # NOTE(review): yaml.load on a user-supplied file; prefer a safe loader if available
        weights = yaml.load(fh)['weights']
    model_paths, model_weights = zip(*weights.items())
    args['model_path'], args['weights'] = model_paths, model_weights
    for model in model_paths:
        assert Path(model).exists(), model
    assert abs(sum(model_weights) - 1) < 1e-3, \
        f'Weights from --sys-comb file should sum to 1.0, given={args["weights"]}'
def train(self, steps: int, batch_size: int):
    """Run ``steps`` batches of combination training, then save the model weights.

    Each batch is scored by ``self.combo`` and passed to ``self.loss_func``; parameter updates
    presumably happen inside ``self.loss_func`` (not visible here — confirm). The final weights,
    keyed by model path, are written as YAML to ``self.w_file``.
    """
    log.info(f"Going to train for {steps}")
    batches = self.exp.get_combo_data(batch_size=batch_size, steps=steps)
    with tqdm(batches, total=steps, unit='step', dynamic_ncols=True) as data_bar:
        for batch in data_bar:
            batch = batch.to(device)
            y_probs = self.combo(batch)  # B x T x V
            loss = self.loss_func(y_probs, y_seqs=batch.y_seqs, norm=batch.y_toks)
            weights_txt = ','.join(format(w, 'g') for w in self.combo.weight)
            data_bar.set_postfix_str(f'loss={loss:g}, weights={weights_txt}', refresh=False)

    paths = [str(p) for p in self.combo.model_paths]
    weights = dict(zip(paths, self.combo.model_weights.tolist()))
    log.info(f" Training finished. {weights}")
    with IO.writer(self.w_file) as wtr:
        yaml.dump(dict(weights=weights), wtr, default_flow_style=False)
def __init__(self, path: Union[str, Path]):
    """Load a vocabulary mapping file for a hub model.

    The file is YAML with ``model_id`` (passed to ``self.load_hub_model``) and ``mapping``
    (token -> [new index, old index]).

    :param path: mapping file
    :raises AssertionError: when the new indices are not exactly 0..N-1, or a reserved token
        is not at its reserved index
    """
    with IO.reader(path) as rdr:
        # NOTE(review): yaml.load on a file; prefer a safe loader if this yaml module has one
        data = yaml.load(rdr)
    hub_api = self.load_hub_model(data['model_id'])
    # these are for XLM-R / RoBERTa from fairseq; generalize it for other models later
    self.bpe = hub_api.bpe
    mapping = data['mapping']
    self.tok2idx = {tok: new_idx for tok, (new_idx, old_idx) in mapping.items()}
    self.idx2tok = sorted(self.tok2idx, key=self.tok2idx.get)
    # idx2tok[i] must be the token whose index is i; the old length check was always true
    assert all(self.tok2idx[tok] == idx for idx, tok in enumerate(self.idx2tok)), \
        f'new indices in {path} are not contiguous from 0'
    for tok, idx in self.reserved():  # reserved are reserved
        assert self.tok2idx[tok] == idx
        assert self.idx2tok[idx] == tok
    self.new_idx2old_idx = {new_idx: old_idx for tok, (new_idx, old_idx) in mapping.items()}
def read_tsv(path: str):
    """Lazily yield the fields of each line of ``path``, split on tab characters.

    Lines are not stripped, so the last field of each record keeps its line terminator.
    The existence assertion runs when iteration starts (this is a generator).
    """
    assert os.path.exists(path)
    with IO.reader(path) as reader:
        for row in reader:
            yield row.split('\t')
def tune_decoder_params(self, exp: Experiment, tune_src: str, tune_ref: str, batch_size: int,
                        trials: int = 10, lowercase=True, beam_size=(1, 4, 8),
                        ensemble=(1, 5, 10), lp_alpha=(0.0, 0.4, 0.6),
                        suggested: List[Tuple[int, int, float]] = None, **fixed_args):
    """Random-search decoder parameters (beam size, ensemble size, length-penalty alpha).

    Scores are cached in ``<work_dir>/tune_step<step>/scores.json`` so tuning can resume;
    the file is rewritten even when decoding fails part way.

    :param exp: experiment whose last saved model is tuned
    :param tune_src: source file of the tuning set
    :param tune_ref: reference file of the tuning set (same number of lines)
    :param batch_size: total batch size; divided by the beam size when decoding
    :param trials: total number of random trials, including ones already in the score cache
    :param lowercase: passed to ``decode_eval_file``
    :param beam_size: choices for the beam size
    :param ensemble: choices for the ensemble size
    :param lp_alpha: choices for the length-penalty alpha (rounded to 2 decimals)
    :param suggested: extra ``(beam, ensemble, lp_alpha)`` combinations to try, as tuples or as
        strings holding a Python tuple literal
    :param fixed_args: passed to every ``decode_eval_file`` call
    :return: ``(best, tune_args)``; ``best`` maps beam_size/ensemble/lp_alpha to the
        best-scoring combination, ``tune_args`` are the call arguments minus the excluded ones
    :raises ValueError: when there is no result at all to choose from
    """
    # Must stay the first statement: captures only the call arguments. Copy the mapping,
    # because frame.f_locals may be refreshed later or be a live proxy (Python 3.13+).
    tune_args = dict(inspect.getargvalues(inspect.currentframe()).locals)
    tune_args.update(fixed_args)
    ex_args = ['exp', 'self', 'fixed_args', 'batch_size', 'max_len']
    if trials == 0:
        ex_args += ['beam_size', 'ensemble', 'lp_alpha']
    for x in ex_args:
        # exclude some args; 'max_len' is only present when given in fixed_args
        tune_args.pop(x, None)

    import ast  # imported after capturing tune_args so it doesn't end up in them

    _, step = exp.get_last_saved_model()
    tune_dir = exp.work_dir / f'tune_step{step}'
    log.info(f"Tune dir = {tune_dir}")
    tune_dir.mkdir(parents=True, exist_ok=True)
    tune_src, tune_ref = Path(tune_src), Path(tune_ref)
    assert tune_src.exists()
    assert tune_ref.exists()
    tune_src, tune_ref = list(IO.get_lines(tune_src)), list(IO.get_lines(tune_ref))
    assert len(tune_src) == len(tune_ref)

    tune_log = tune_dir / 'scores.json'  # resume the tuning
    memory: Dict[Tuple, float] = {}
    if tune_log.exists():
        with tune_log.open() as fh:
            data = json.load(fh)
        # JSON keys cant be tuples, so they were stringified; parse them back as literals
        # (literal_eval instead of eval: never execute code read from a file)
        memory = {ast.literal_eval(k): v for k, v in data.items()}

    beam_sizes, ensembles, lp_alphas = [], [], []
    if suggested:
        if isinstance(suggested[0], str):
            suggested = [ast.literal_eval(x) for x in suggested]
        suggested = [(x[0], x[1], round(x[2], 2)) for x in suggested]
        suggested_new = [x for x in suggested if x not in memory]
        beam_sizes += [x[0] for x in suggested_new]
        ensembles += [x[1] for x in suggested_new]
        lp_alphas += [x[2] for x in suggested_new]

    new_trials = trials - len(memory)
    if new_trials > 0:
        beam_sizes += [random.choice(beam_size) for _ in range(new_trials)]
        ensembles += [random.choice(ensemble) for _ in range(new_trials)]
        lp_alphas += [round(random.choice(lp_alpha), 2) for _ in range(new_trials)]

    # ensembling is somewhat costlier, so try minimize the model ensembling, by grouping them together
    grouped_ens = defaultdict(list)
    for b, ens, lp in zip(beam_sizes, ensembles, lp_alphas):
        grouped_ens[ens].append((b, lp))
    try:
        for ens, args in grouped_ens.items():
            decoder = Decoder.new(exp, ensemble=ens)
            for b_s, lp_a in args:
                eff_batch_size = batch_size // b_s  # effective batch size
                name = f'tune_step{step}_beam{b_s}_ens{ens}_lp{lp_a:.2f}'
                log.info(name)
                out_file = tune_dir / f'{name}.out.tsv'
                score = self.decode_eval_file(decoder, tune_src, out_file, tune_ref,
                                              batch_size=eff_batch_size, beam_size=b_s,
                                              lp_alpha=lp_a, lowercase=lowercase, **fixed_args)
                memory[(b_s, ens, lp_a)] = score
        if not memory:
            raise ValueError('No tuning results: set trials > 0 or give suggested params')
        # first combination with the highest score (same tie-breaking as a stable sort)
        best_params = max(memory, key=memory.get)
        return dict(zip(['beam_size', 'ensemble', 'lp_alpha'], best_params)), tune_args
    finally:
        # JSON keys cant be tuples, so we stringify them
        data = {str(k): v for k, v in memory.items()}
        IO.write_lines(tune_log, json.dumps(data))
def write_lines(lines, path):
    """Write every string in ``lines`` to ``path``, each followed by a newline."""
    log.info(f"Storing data at {path}")
    with IO.writer(path) as out:
        for text in lines:
            out.write(text + '\n')
def __init__(self, exp: Experiment,
             model: Optional[NMTModel] = None,
             model_factory: Optional[Callable] = None,
             optim: str = 'ADAM',
             **optim_args):
    """Set up a trainer: model (given, or built and possibly resumed), optimizer,
    TensorBoard writer, samples and decoder.

    :param exp: experiment providing model args, checkpoints, work dir and samples file
    :param model: an already built model; when given, no checkpoint is restored here
    :param model_factory: required when ``model`` is None; called as
        ``model_factory(exp=exp, **exp.model_args)`` and must return ``(model, args)``
    :param optim: key into ``Optims`` selecting the inner optimizer
    :param optim_args: optimizer keyword args. ``lr``, ``betas`` and ``eps`` get defaults;
        ``warmup_steps``, ``label_smoothing`` and ``constant`` are popped before the inner
        optimizer is built and put back before the args are stored on the experiment.
    """
    self.start_step = 0
    self.last_step = -1
    self.exp = exp
    optim_state = None
    if model:
        self.model = model
    else:
        args = exp.model_args
        assert args
        assert model_factory
        self.model, args = model_factory(exp=exp, **args)
        exp.model_args = args
        last_model, self.last_step = self.exp.get_last_saved_model()
        if last_model:
            self.start_step = self.last_step + 1
            log.info(f"Resuming training from step:{self.start_step}, model={last_model}")
            # NOTE(review): no map_location is given; a checkpoint saved on GPU may fail to
            # load on a CPU-only host — confirm this is intended
            state = torch.load(last_model)
            # the checkpoint is either a dict holding 'model_state' (and possibly
            # 'optim_state') or the model state itself
            model_state = state['model_state'] if 'model_state' in state else state
            if 'optim_state' in state:
                optim_state = state['optim_state']
            self.model.load_state_dict(model_state)
        else:
            log.info("No earlier check point found. Looks like this is a fresh start")

    # making optimizer
    optim_args['lr'] = optim_args.get('lr', 0.1)
    optim_args['betas'] = optim_args.get('betas', [0.9, 0.98])
    optim_args['eps'] = optim_args.get('eps', 1e-9)
    # popped here so they are not passed to the inner optimizer; restored further below
    warmup_steps = optim_args.pop('warmup_steps', 8000)
    self._smoothing = optim_args.pop('label_smoothing', 0.1)
    constant = optim_args.pop('constant', 2)

    # the model is moved to device before the optimizer is built over its parameters
    self.model = self.model.to(device)

    inner_opt = Optims[optim].new(self.model.parameters(), **optim_args)
    if optim_state:
        log.info("restoring optimizer state from checkpoint")
        try:
            inner_opt.load_state_dict(optim_state)
        except Exception:
            # best effort: training continues with a fresh optimizer state
            log.exception("Unable to restore optimizer, skipping it.")
    self.opt = NoamOpt(self.model.model_dim, constant, warmup_steps, inner_opt, step=self.start_step)

    # put the popped args back so the args stored on the experiment are complete
    optim_args.update(dict(warmup_steps=warmup_steps,
                           label_smoothing=self._smoothing,
                           constant=constant))
    if self.exp.read_only:
        self.tbd = NoOpSummaryWriter()
    else:
        self.tbd = SummaryWriter(log_dir=str(exp.work_dir / 'tensorboard'))

    self.exp.optim_args = optim, optim_args
    if not self.exp.read_only:
        self.exp.persist_state()
    self.samples = None
    if exp.samples_file.exists():
        with IO.reader(exp.samples_file) as f:
            # each sample line is split on tabs into its fields
            self.samples = [line.strip().split('\t') for line in f]
        log.info(f"Found {len(self.samples)} sample records")
        if self.start_step == 0:
            # samples are logged to TensorBoard only on a fresh start
            for samp_num, sample in enumerate(self.samples):
                self.tbd.add_text(f"sample/{samp_num}", " || ".join(sample), 0)

    from rtg.module.decoder import Decoder
    self.decoder = Decoder.new(self.exp, self.model)
    if self.start_step == 0:
        self.init_embeddings()
    self.model = self.model.to(device)
def inherit_parent(self):
    """Inherit vocabularies and/or model state from the parent experiment in the config.

    Reads ``self.config['parent']``:
      - ``experiment``: parent experiment dir (opened read-only)
      - ``vocab``: optional mapping of this experiment's field ('src', 'tgt', 'shared') to the
        parent's field whose file is copied over; the parent's ``prep.codec_lib`` is copied too
      - ``model``: optional; when ``args`` is truthy the parent's model args are copied, and the
        first ``ensemble`` (default 1) checkpoints of ``list_models(sort='step', desc=True)``
        are averaged and saved to ``self.parent_model_state``
      - ``shrink``: optional; vocabularies are shrunk via ``self.shrink_vocabs()`` and the
        embedding/generator rows of the averaged state are remapped to the new indices
    """
    parent = self.config['parent']
    parent_exp = TranslationExperiment(parent['experiment'], read_only=True)
    log.info(f"Parent experiment: {parent_exp.work_dir}")
    # NOTE(review): the return value is ignored; if this is meant as a readiness check it
    # should probably be asserted
    parent_exp.has_prepared()
    vocab_sepc = parent.get('vocab')
    if vocab_sepc:
        log.info(f"Parent vocabs inheritance spec: {vocab_sepc}")
        codec_lib = parent_exp.config['prep'].get('codec_lib')
        if codec_lib:
            self.config['prep']['codec_lib'] = codec_lib

        def _locate_field_file(exp: TranslationExperiment, name, check_exists=False) -> Path:
            # maps a field name ('src', 'tgt', 'shared') to that experiment's field file
            switch = {
                'src': exp._src_field_file,
                'tgt': exp._tgt_field_file,
                'shared': exp._shared_field_file
            }
            assert name in switch, f'{name} not allowed; valid options= {switch.keys()}'
            file = switch[name]
            if check_exists:
                assert file.exists(), f'{file} doesnot exist; for {name} of {exp.work_dir}'
            return file

        for to_field, from_field in vocab_sepc.items():
            from_field_file = _locate_field_file(parent_exp, from_field, check_exists=True)
            to_field_file = _locate_field_file(self, to_field, check_exists=False)
            IO.copy_file(from_field_file, to_field_file)
        self.reload_vocabs()
    else:
        log.info("No vocabularies are inherited from parent")
    model_sepc = parent.get('model')
    if model_sepc:
        log.info("Parent model inheritance spec")
        if model_sepc.get('args'):
            self.model_args = parent_exp.model_args
        ensemble = model_sepc.get('ensemble', 1)
        model_paths = parent_exp.list_models(sort='step', desc=True)[:ensemble]
        log.info(f"Averaging {len(model_paths)} checkpoints of parent model: \n{model_paths}")
        from rtg.module.decoder import Decoder
        avg_state = Decoder.average_states(model_paths=model_paths)
        log.info(f"Saving parent model's state to {self.parent_model_state}")
        torch.save(avg_state, self.parent_model_state)

    # NOTE(review): avg_state is only bound when a 'model' spec is given; a 'shrink' spec
    # without it would fail with NameError — confirm whether shrink requires a model spec
    shrink_spec = parent.get('shrink')
    if shrink_spec:
        remap_src, remap_tgt = self.shrink_vocabs()

        def map_rows(mapping: List[int], source: torch.Tensor, name=''):
            # new tensor whose row i is source[mapping[i]]; same dtype and device as source
            assert max(mapping) < len(source)
            target = torch.zeros((len(mapping), *source.shape[1:]),
                                 dtype=source.dtype,
                                 device=source.device)
            for new_idx, old_idx in enumerate(mapping):
                target[new_idx] = source[old_idx]
            log.info(f"Mapped {name} {source.shape} --> {target.shape} ")
            return target

        """
        src_embed.0.lut.weight [N x d]
        tgt_embed.0.lut.weight [N x d]
        generator.proj.weight [N x d]
        generator.proj.bias [N]
        """
        if remap_src:
            key = 'src_embed.0.lut.weight'
            avg_state[key] = map_rows(remap_src, avg_state[key], name=key)
        if remap_tgt:
            map_keys = ['tgt_embed.0.lut.weight', 'generator.proj.weight', 'generator.proj.bias']
            for key in map_keys:
                if key not in avg_state:
                    log.warning(f'{key} not found in avg_state of parent model. Mapping skipped')
                    continue
                avg_state[key] = map_rows(remap_tgt, avg_state[key], name=key)
        # keep the un-shrunk state alongside as '.orig' before overwriting it
        if self.parent_model_state.exists():
            self.parent_model_state.rename(self.parent_model_state.with_suffix('.orig'))
        torch.save(avg_state, self.parent_model_state)
        self.persist_state()  # this will fix src_vocab and tgt_vocab of model_args conf
def _write_dict(dict, path: Path):
    """Write the mapping ``dict`` to ``path`` as one "key<TAB>value" line per item.

    The parameter name shadows the builtin ``dict`` inside this function; it is kept for
    keyword-argument compatibility.
    """
    with IO.writer(path) as out:
        for line in (f"{k}\t{v}\n" for k, v in dict.items()):
            out.write(line)
def load_conf(inp: Union[str, Path]):
    """Load and return the YAML content of the config file ``inp``."""
    with IO.reader(inp) as reader:
        # NOTE(review): yaml.load on a file; prefer a safe loader if this yaml module has one
        conf = yaml.load(reader)
    return conf