def fork_experiment(from_exp: Path, to_exp: Path, conf: bool, vocab: bool, data: bool, code: bool):
    assert from_exp.exists()
    log.info(f'Fork: {str(from_exp)} → {str(to_exp)}')
    if not to_exp.exists():
        log.info(f"Create dir {str(to_exp)}")
        to_exp.mkdir(parents=True)
    if conf:
        conf_file = to_exp / 'conf.yml'
        IO.maybe_backup(conf_file)
        IO.copy_file(from_exp / 'conf.yml', conf_file)
    if data:
        to_data_dir = (to_exp / 'data')
        from_data_dir = from_exp / 'data'
        if to_data_dir.is_symlink():
            log.info(f"removing the existing data link: {to_data_dir.resolve()}")
            to_data_dir.unlink()
        assert not to_data_dir.exists()
        assert from_data_dir.exists()
        log.info(f"link {to_data_dir} → {from_data_dir}")
        to_data_dir.symlink_to(from_data_dir.resolve())
        (to_exp / '_PREPARED').touch(exist_ok=True)
    if not data and vocab:  # just the vocab
        Experiment(from_exp, read_only=True).copy_vocabs(
            Experiment(to_exp, config={'Not': 'Empty'}, read_only=True))
    if code:
        for f in ['rtg.zip', 'githead']:
            src = from_exp / f
            if not src.exists():
                log.warning(f"File Not Found: {src}")
                continue
            IO.copy_file(src, to_exp / f)
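# Hedged usage sketch for fork_experiment above: the run-directory names are made-up
# placeholders; only the signature shown above is assumed. The call is wrapped in a
# helper so it does not execute at import time.
def _example_fork_experiment():
    from pathlib import Path
    fork_experiment(Path('runs/001-base'), Path('runs/002-fork'),
                    conf=True, vocab=True, data=True, code=True)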
def write(cls, path, records: Iterator[ParallelSeqRecord]):
    if path.exists():
        log.warning(f"Overwriting {path} with new records")
        os.remove(str(path))
    maybe_tmp = IO.maybe_tmpfs(path)
    log.info(f'Creating {maybe_tmp}')
    conn = sqlite3.connect(str(maybe_tmp))
    cur = conn.cursor()
    cur.execute(cls.TABLE_STATEMENT)
    cur.execute(cls.INDEX_X_LEN)
    cur.execute(cls.INDEX_Y_LEN)
    cur.execute(f"PRAGMA user_version = {cls.CUR_VERSION};")

    count = 0
    for x_seq, y_seq in records:
        # use numpy; it is a lot more efficient
        if not isinstance(x_seq, np.ndarray):
            x_seq = np.array(x_seq, dtype=np.int32)
        if y_seq is not None and not isinstance(y_seq, np.ndarray):
            y_seq = np.array(y_seq, dtype=np.int32)
        values = (x_seq.tobytes(),
                  None if y_seq is None else y_seq.tobytes(),
                  len(x_seq), len(y_seq) if y_seq is not None else -1)
        cur.execute(cls.INSERT_STMT, values)
        count += 1
    cur.close()
    conn.commit()
    if maybe_tmp != path:
        # bring the file back to the original location where it should be
        IO.copy_file(maybe_tmp, path)
    log.info(f"stored {count} rows in {path}")
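# Hedged usage sketch for the write() classmethod above: `storage_cls` stands in for
# whichever class defines it (the class name is not shown here). Records are
# (x_ids, y_ids) pairs of token ids; y_ids may be None for monolingual rows.
def _example_write_db(storage_cls):
    from pathlib import Path
    records = iter([([1, 2, 3], [4, 5]),
                    ([6, 7], None)])  # second row has no target side
    storage_cls.write(Path('example.db'), records)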
def forward(self, x, score=None, gen_probs=True, log_probs=True):
    """
    :param x: features or hidden states
    :param score: what scores do you want in return? Your options are
        'logits' -- scores without any normalization
        'softmax' -- raw probs for multi-class
        'log_softmax' -- log probs for multi-class
        'sigmoid' -- for multi-label task
    :param gen_probs: (deprecated, use score='logits') False to get logits; default is True
    :param log_probs: (deprecated, use score='log_softmax' or 'softmax').
        False to get raw probs from softmax, True to get probs from log_softmax.
    :return: scores based on the choice of score=xxx
    """
    # made this mess to preserve backward compatibility
    if not score:
        score = 'logits'
        if gen_probs:
            score = 'log_softmax' if log_probs else 'softmax'
        warn_msg = f'API deprecated. use "score={score}" attribute.'
        if warn_msg not in self.warn_msgs:  # warn only once
            self.warn_msgs.add(warn_msg)
            log.warning(warn_msg)
            traceback.print_stack(limit=6)
    assert score in self.scores, f'{self.scores.keys()} supported but given "{score}"'
    if score == 'embedding' or score == 'identity':
        return x
    x = self.proj(x)
    return self.scores[score](x, dim=-1)
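# Hedged usage sketch for forward() above: prefer the explicit score=... argument over
# the deprecated gen_probs/log_probs flags. `generator` is assumed to be an instance of
# the module that defines this forward(), and `hidden` a [batch, seq, d_model] tensor;
# both are placeholders supplied by the caller.
def _example_generator_scores(generator, hidden):
    logits = generator(hidden, score='logits')          # unnormalized scores
    log_probs = generator(hidden, score='log_softmax')  # multi-class log probs
    probs = generator(hidden, score='sigmoid')          # multi-label probabilities
    return logits, log_probs, probs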
def get_stats(data, n, limit=-1):
    n_seqs = 0
    lens = []
    freqs = np.zeros(n, dtype=np.int32)
    for seq in tqdm(data):
        n_seqs += 1
        for i in seq:
            freqs[i] += 1
        lens.append(len(seq))
        if limit > 0 and n_seqs > limit:
            log.warning(f"Aborting at {n_seqs} records though there are more in the dataset")
            break
    lens = np.array(lens)
    total_toks = np.sum(lens)
    assert total_toks == np.sum(freqs)  # sanity
    # Skip zero frequencies; they could be reserved words or from the other side if vocab is shared
    n_zero_types = sum(1 for f in freqs if f == 0)
    n_effective = n - n_zero_types
    probs = freqs / total_toks
    imbalance = 0.5 * np.sum(np.abs(1 / n_effective - probs))
    res = dict(n_seqs=n_seqs, total_toks=total_toks,
               mean_len=np.mean(lens), median_len=np.median(lens), max_len=np.max(lens),
               EMD=imbalance, n=n, zero_types=n_zero_types, effective_n=n_effective)
    return freqs, res
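# Hedged usage sketch for get_stats above: toy token-id sequences with a vocabulary of
# size n=5. The 'EMD' entry in the result is 0.5 * sum(|1/effective_n - p_i|), i.e. the
# imbalance of the observed unigram distribution relative to a uniform one over the
# effectively-used vocabulary, as computed in the function above.
def _example_get_stats():
    toy_data = [[1, 2, 2, 3], [2, 4], [1, 3, 3, 3, 4]]
    freqs, stats = get_stats(toy_data, n=5)
    return freqs, stats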
def main():
    args = parse_args()
    seed = args.pop("seed")
    if seed:
        log.info(f"Seed for random number generator: {seed}")
        import random
        import torch
        random.seed(seed)
        torch.manual_seed(seed)

    work_dir = Path(args.pop('work_dir'))
    is_big = load_conf(work_dir / 'conf.yml').get('spark', {})

    if is_big:
        log.info("Big experiment mode enabled; checking pyspark backend")
        try:
            import pyspark
        except:
            log.warning("unable to import pyspark. Please do 'pip install pyspark' and run again")
            raise
        from rtg.big.exp import BigTranslationExperiment
        exp = BigTranslationExperiment(work_dir=work_dir)
    else:
        exp = Experiment(work_dir=work_dir)
    assert exp.has_prepared(), f'Experiment dir {exp.work_dir} is not ready to train. ' \
                               f'Please run "prep" sub task'
    exp.train(args)
def train(self, steps: int, check_point: int, batch_size: int,
          check_pt_callback: Optional[Callable] = None, side='tgt', ctx_size=2, **args):
    log.info(f"using side={side}, ctx_size={ctx_size}")
    reader = DataReader(self.exp, side=side)
    rem_steps = steps - self.start_step
    if rem_steps <= 0:
        log.warning(f"Already trained up to {self.start_step - 1}. Skipped")
        return
    train_data = reader.get_training_data(batch_size=batch_size, n_batches=rem_steps,
                                          ctx_size=ctx_size)
    val_data = reader.get_val_data(batch_size=batch_size, ctx_size=ctx_size)
    train_loss, n = 0.0, 0

    def _make_checkpt(step, train_loss):
        with torch.no_grad():
            val_loss = self.run_valid_epoch(val_data)
        log.info(f"Checkpoint at {step}")
        self.save_embeddings(step, train_loss, val_loss, txt=True)
        self.tbd.add_scalars('losses', {'training': train_loss,
                                        'valid_loss': val_loss}, global_step=step)

    with tqdm(train_data, initial=self.start_step, total=rem_steps + 1, unit='batch',
              dynamic_ncols=True) as data_bar:
        for i, (xs, ys) in enumerate(data_bar, start=self.start_step):
            self.model.zero_grad()
            xs, ys = xs.to(device), ys.to(device)
            log_probs = self.model(xs)
            loss = self.loss_func(log_probs, ys)
            self.tbd.add_scalars('training', {'step_loss': loss.item(),
                                              'learn_rate': self.opt.curr_lr},
                                 self.opt.curr_step)
            progress_msg = f', loss={loss:g} LR={self.opt.curr_lr:g}'
            data_bar.set_postfix_str(progress_msg, refresh=False)

            train_loss += loss.item()
            n += len(ys)
            loss.backward()
            self.opt.step()
            self.opt.zero_grad()

            if i > 0 and i % check_point == 0:
                _make_checkpt(i, train_loss / n)
                train_loss, n = 0.0, 0
    if n > 0:
        _make_checkpt(steps, train_loss / n)
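# Hedged usage sketch for train() above: `trainer` stands in for an instance of the class
# that defines this method; the step, checkpoint, and batch numbers are illustrative only.
def _example_train_embeddings(trainer):
    trainer.train(steps=10_000, check_point=1_000, batch_size=256, side='tgt', ctx_size=2)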
def make_model(cls, src_vocab, tgt_vocab, enc_layers=6, dec_layers=6, hid_size=512,
               ff_size=2048, n_heads=8, attn_bias=True, attn_dropout=0.1, dropout=0.2,
               activation='relu',
               enc_depth_probs: List[float] = (1.0, 0.9, 0.8, 0.7, 0.6, 0.5),
               dec_depth_probs: List[float] = (1.0, 0.9, 0.8, 0.7, 0.6, 0.5),
               tied_emb='three-way', exp: Experiment = None):
    """Helper: Construct a model from hyper parameters."""
    assert len(enc_depth_probs) == enc_layers
    assert len(dec_depth_probs) == dec_layers
    # get all args for reconstruction at a later phase
    args = get_my_args(exclusions=['cls', 'exp'])
    assert activation in {'relu', 'elu', 'gelu'}
    log.info(f"Make model, Args={args}")
    c = copy.deepcopy
    attn = tfm.MultiHeadedAttention(n_heads, hid_size, dropout=attn_dropout, bias=attn_bias)
    ff = tfm.PositionwiseFeedForward(hid_size, ff_size, dropout, activation=activation)

    if enc_layers == 0:
        log.warning("Zero encoder layers!")
    encoder = SkipEncoder(tfm.EncoderLayer(hid_size, c(attn), c(ff), dropout),
                          enc_layers, enc_depth_probs)

    assert dec_layers > 0
    decoder = SkipDecoder(tfm.DecoderLayer(hid_size, c(attn), c(attn), c(ff), dropout),
                          dec_layers, dec_depth_probs)

    src_emb = nn.Sequential(tfm.Embeddings(hid_size, src_vocab),
                            tfm.PositionalEncoding(hid_size, dropout))
    tgt_emb = nn.Sequential(tfm.Embeddings(hid_size, tgt_vocab),
                            tfm.PositionalEncoding(hid_size, dropout))
    generator = tfm.Generator(hid_size, tgt_vocab)

    model = cls(encoder, decoder, src_emb, tgt_emb, generator)
    if tied_emb:
        model.tie_embeddings(tied_emb)

    model.init_params()
    return model, args
def enable_fp16(self):
    if not self.fp16:
        self.fp16 = True
        self._scaler = GradScaler(enabled=self.fp16)
        log.info("Enabling FP16 / Automatic Mixed Precision training")
    else:
        log.warning("fp16 is already enabled")
def from_lines(cls, lines: Iterator[str], batch_size: int, vocab: Field, sort=True,
               max_src_len=0, max_len_buffer=0):
    """
    Note: this changes the order based on sequence length if sort=True
    :param lines: stream of input lines
    :param batch_size: number of tokens in batch
    :param vocab: Field to use for mapping word pieces to ids
    :param sort: sort based on descending order of length
    :param max_src_len: truncate source at this length; 0 disables truncation
    :return: stream of DecoderBatches
    """
    log.info("Tokenizing sequences")
    buffer = []
    for i, line in enumerate(lines):
        line = line.strip()
        if not line:
            log.warning(f"line {i + 1} was empty. inserting a dot (.). "
                        f"Empty lines are problematic when you want line-by-line alignment...")
            line = "."
        cols = line.split('\t')
        id, ref = None, None
        if len(cols) == 1:  # SRC
            src = cols[0]
        elif len(cols) == 2:  # ID \t SRC
            id, src = cols
        else:  # ID \t SRC \t REF
            id, src, ref = cols[:3]
        seq = vocab.encode_as_ids(src, add_eos=True, add_bos=False)
        if max_src_len > 0 and len(seq) > max_src_len:
            log.warning(f"Line {i} full length={len(seq)}; truncated to {max_src_len}")
            seq = seq[:max_src_len]
        buffer.append((i, src, ref, seq, id))  # idx, src, ref, seq, id

    if sort:
        log.info(f"Sorting based on the length. total = {len(buffer)}")
        buffer = sorted(buffer, reverse=True, key=lambda x: len(x[3]))  # sort by length of seq

    batch = cls()
    batch.max_len_buffer = max_len_buffer
    for idx, src, ref, seq, _id in buffer:
        batch.add(idx=idx, src=src, ref=ref, seq=seq, id=_id)
        if batch.padded_tok_count >= batch_size:
            yield batch
            batch = cls()
            batch.max_len_buffer = max_len_buffer

    if batch.line_count > 0:
        yield batch
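# Hedged usage sketch for from_lines() above: each input line is "SRC", "ID<TAB>SRC",
# or "ID<TAB>SRC<TAB>REF". `batch_cls` stands in for the class that defines the
# classmethod, and `vocab` for a Field instance; the 4096-token batch budget is a placeholder.
def _example_decoder_batches(batch_cls, vocab):
    lines = iter(["hello world",
                  "id1\thello again",
                  "id2\tgood morning\tguten Morgen"])
    yield from batch_cls.from_lines(lines, batch_size=4096, vocab=vocab,
                                    sort=True, max_src_len=512)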
def make_model(cls, src_vocab, tgt_vocab, enc_layers=6, dec_layers=6, hid_size=512,
               n_heads=8, attn_dropout=0.1, dropout=0.2, activation='relu',
               eff_dims: List[int] = (1024, 1024, 2048, 2048, 1024, 1024),  # Using tuple for immutability
               dff_dims: List[int] = (1024, 1024, 2048, 2048, 1024, 1024),
               tied_emb='three-way', exp: Experiment = None):
    """Helper: Construct a model from hyper parameters."""
    assert enc_layers == len(eff_dims)
    assert dec_layers == len(dff_dims)
    # get all args for reconstruction at a later phase
    args = get_my_args(exclusions=['cls', 'exp'])
    assert activation in {'relu', 'elu', 'gelu'}
    log.info(f"Make model, Args={args}")

    if enc_layers == 0:
        log.warning("Zero encoder layers!")
    encoder = WidthVaryingEncoder(d_model=hid_size, ff_dims=eff_dims, N=enc_layers,
                                  n_heads=n_heads, attn_dropout=attn_dropout,
                                  dropout=dropout, activation=activation)

    assert dec_layers > 0
    decoder = WidthVaryingDecoder(d_model=hid_size, ff_dims=dff_dims, N=dec_layers,
                                  n_heads=n_heads, attn_dropout=attn_dropout,
                                  dropout=dropout, activation=activation)

    src_emb = nn.Sequential(Embeddings(hid_size, src_vocab),
                            PositionalEncoding(hid_size, dropout))
    tgt_emb = nn.Sequential(Embeddings(hid_size, tgt_vocab),
                            PositionalEncoding(hid_size, dropout))
    generator = Generator(hid_size, tgt_vocab)

    model = cls(encoder, decoder, src_emb, tgt_emb, generator)
    if tied_emb:
        model.tie_embeddings(tied_emb)

    model.init_params()
    return model, args
def backward(self, loss):
    if torch.isnan(loss):
        log.warning('loss is nan; backward() skipped')
        return
    if self.fp16:
        loss = self._scaler.scale(loss)
        # to apply norm: TODO: unscale gradients; refer to docs
        # torch.nn.utils.clip_grad_norm_(self._amp.master_params(opt.optimizer), self.max_norm)
    loss.backward()
def make_check_point(self, train_loss: float, val_loss: float, keep_models: int,
                     log_embedding=False):
    """
    Checkpoint the model
    :param train_loss: training loss value
    :param val_loss: loss on validation set
    :param keep_models: how many checkpoints to keep on the file system
    :return:
    """
    step_num = self.opt.curr_step
    if step_num == self.last_step:
        log.warning("Ignoring checkpt request")
        return  # calling multiple times doesn't save
    log.info(f"Checkpoint at optimizer step {step_num}. Training Loss {train_loss:g},"
             f" Validation Loss:{val_loss:g}")
    self.show_samples()

    self.tbd.add_scalars('losses', {'train_loss': train_loss,
                                    'valid_loss': val_loss}, step_num)
    if log_embedding:
        # TODO: add metadata (text) of each subword
        # TODO: update tag to include tie configuration
        self.tbd.add_embedding(self.model.generator.proj.weight,
                               global_step=step_num, tag='Target embeddings')

    # Unwrap model state from DataParallel and persist
    model = (self.model.module if hasattr(self.model, 'module') else self.model)
    state = {
        'model_state': model.state_dict(),
        'optim_state': self.opt.optimizer.state_dict(),
        'step': step_num,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'time': time.time(),
        'rtg_version': rtg.__version__,
        'model_type': self.exp.model_type,
        'model_args': self.exp.model_args,
    }
    if dtorch.fp16:
        state['amp_state'] = dtorch._scaler.state_dict()
    self.exp.store_model(step_num, state, train_score=train_loss,
                         val_score=val_loss, keep=keep_models)
    self.last_step = step_num
def _pre_process_parallel(self, src_key: str, tgt_key: str, out_file: Path,
                          args: Optional[Dict[str, Any]] = None, line_check=False,
                          **kwargs):
    """
    Pre process records of a parallel corpus
    :param args: all arguments for 'prep' task
    :param src_key: key that contains source sequences
    :param tgt_key: key that contains target sequences
    :param out_file: path to store processed TSV data (compresses if name ends with .gz)
    :return:
    """
    if kwargs:
        log.warning(f"The following args are ignored: {kwargs}")
    if not out_file.name.endswith(".nldb"):
        if 'train' in out_file.name:
            log.warning("set .nldb extension to enable spark")
        return super()._pre_process_parallel(src_key=src_key, tgt_key=tgt_key,
                                             out_file=out_file, args=args,
                                             line_check=line_check)
    args = args if args else self.config['prep']
    log.info(f"Going to prep files {src_key} and {tgt_key}")
    assert src_key in args, f'{src_key} not found in experiment config or args'
    assert tgt_key in args, f'{tgt_key} not found in experiment config or args'

    with log_resources(f"create {out_file.name}"):
        with spark_session(config=self.spark_conf) as spark:
            rdd, total = read_raw_parallel_recs(
                spark, src_path=args[src_key], tgt_path=args[tgt_key],
                truncate=args['truncate'], src_len=args['src_len'], tgt_len=args['tgt_len'],
                src_tokenizer=self.src_vocab.encode_as_ids,
                tgt_tokenizer=self.tgt_vocab.encode_as_ids)
            id_rdd = rdd.map(lambda r: (r[0], (r[1], r[2])))  # (id, (x, y))
            max_part_size = args.get('max_part_size', 1_000_000)
            n_parts = math.ceil(total / max_part_size)
            log.info(f"Writing to {out_file}; {n_parts} parts,"
                     f" not exceeding {max_part_size:,} records in each part")
            rdd_as_db(id_rdd, db_path=out_file, field_names=['x', 'y'],
                      overwrite=True, repartition=n_parts)
def maybe_init_from_parent(self, exp: 'TranslationExperiment'):
    if exp.parent_model_state.exists():
        log.info("YES Initialising from a parent model")
        device = next(self.parameters()).device  # device of self model
        state = torch.load(exp.parent_model_state, map_location=device)
        error = self.load_state_dict(state, strict=False)
        log.info("YES Initialized from the parent model")
        if error.missing_keys or error.unexpected_keys:
            log.warning(f"Error keys: {error}")
    else:
        log.info("NOT initialising from parent model")
def make_model(cls, src_vocab, tgt_vocab, enc_layers=6, dec_layers=6, hid_size=512,
               ff_size=2048, n_heads=8, dropout=0.1, tied_emb='three-way',
               activation='relu', exp: Experiment = None):
    "Helper: Construct a model from hyper parameters."
    # get all args for reconstruction at a later phase
    _, _, _, args = inspect.getargvalues(inspect.currentframe())
    for exclusion in ['cls', 'exp']:
        del args[exclusion]  # exclude some args
    # In case you are wondering why I didn't use **kwargs here:
    # these args are read from a conf file where the user can introduce errors, so the parameter
    # validation and default value assignment is implicitly done by the function call for us :)
    assert activation in {'relu', 'elu', 'gelu'}
    log.info(f"Make model, Args={args}")
    c = copy.deepcopy
    attn = MultiHeadedAttention(n_heads, hid_size, dropout=dropout)
    ff = PositionwiseFeedForward(hid_size, ff_size, dropout, activation=activation)

    if enc_layers == 0:
        log.warning("Zero encoder layers!")
    encoder = Encoder(EncoderLayer(hid_size, c(attn), c(ff), dropout), enc_layers)

    assert dec_layers > 0
    decoder = Decoder(DecoderLayer(hid_size, c(attn), c(attn), c(ff), dropout), dec_layers)

    src_emb = nn.Sequential(Embeddings(hid_size, src_vocab),
                            PositionalEncoding(hid_size, dropout))
    tgt_emb = nn.Sequential(Embeddings(hid_size, tgt_vocab),
                            PositionalEncoding(hid_size, dropout))
    generator = Generator(hid_size, tgt_vocab)

    model = cls(encoder, decoder, src_emb, tgt_emb, generator)
    if tied_emb:
        model.tie_embeddings(tied_emb)

    model.init_params()
    return model, args
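# Hedged usage sketch for make_model() above: `model_cls` stands in for the class that
# defines this classmethod; the vocabulary sizes and hyper parameters below are
# illustrative values only.
def _example_make_model(model_cls):
    model, args = model_cls.make_model(src_vocab=32_000, tgt_vocab=32_000,
                                       enc_layers=6, dec_layers=6, hid_size=512,
                                       ff_size=2048, n_heads=8, dropout=0.1,
                                       tied_emb='three-way', activation='gelu')
    return model, args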
def safe_delete(cls, path: Path):
    try:
        if path.exists():
            if path.is_file():
                log.info(f"Delete file {path}")
                path.unlink()
            elif path.is_dir():
                log.info(f"Delete dir {path}")
                path.rmdir()
            else:
                log.warning(f"Could not delete {path}")
    except:
        log.exception(f"Error while cleaning up {path}")
def __iter__(self) -> Iterator[IdExample]:
    for d in self.read_all():
        id, x, y = d['id'], d['x'], d.get('y')
        if x is None or y is None or len(x) == 0 or len(y) == 0:
            log.warning(f"Ignoring an empty record x:{0 if x is None else len(x)}"
                        f" y:{0 if y is None else len(y)}")
            continue
        if len(x) > self.max_src_len or len(y) > self.max_tgt_len:
            if self.truncate:
                x, y = x[:self.max_src_len], y[:self.max_tgt_len]
            else:  # skip this record
                continue
        yield IdExample(x=x, y=y, id=id)
def __iter__(self) -> Iterator[Example]:
    for d in self.read_all():
        x, y = d['x'], d.get('y')
        if not x or not y:
            log.warning(f"Ignoring an empty record x:{0 if x is None else len(x)}"
                        f" y:{0 if y is None else len(y)}")
            continue
        if len(x) > self.max_src_len or len(y) > self.max_tgt_len:
            if self.truncate:
                x, y = x[:self.max_src_len], y[:self.max_tgt_len]
            else:  # skip this record
                continue
        yield Example(x, y)
def __init__(self, model: TransformerNMT, field, x_seqs, x_lens=None, multi_label=False):
    super().__init__(model, field)
    self.x_mask = (x_seqs != field.pad_idx).unsqueeze(1)
    self.memory = self.model.encode(x_seqs, self.x_mask)
    self.multi_label = multi_label
    if multi_label and not type(self).multi_label_warned:
        log.warning(">>> Multi-label decoding mode enabled")
        type(self).multi_label_warned = True
def run_command(cmd_line: str, fail_on_error=True):
    log.info(f'RUN:: {cmd_line}')
    proc = subprocess.run(cmd_line, shell=True, capture_output=True, text=True)
    if proc.returncode == 0:
        log.info(f"STDOUT={proc.stdout}")
        if proc.stderr:
            log.warning(f"STDERR={proc.stderr}")
    else:
        msg = f'CMD={cmd_line}\nCODE={proc.returncode}\nSTDOUT={proc.stdout}\nSTDERR={proc.stderr}'
        if fail_on_error:
            raise Exception('subprocess failed\n' + msg)
        else:
            log.warning(msg)
    return proc.returncode == 0
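# Hedged usage sketch for run_command above: the shell commands are arbitrary examples.
def _example_run_command():
    ok = run_command('ls -l /tmp', fail_on_error=False)  # returns False instead of raising
    if not ok:
        log.warning('listing /tmp failed; continuing anyway')
    run_command('mkdir -p /tmp/rtg-example')  # raises on a non-zero exit status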
def store_model(self, epoch: int, model, train_score: float, val_score: float, keep: int,
                prefix='model', keeper_sort='step'):
    """
    Saves model to a given path
    :param epoch: epoch number of model
    :param model: model object itself
    :param train_score: score of model on training split
    :param val_score: score of model on validation split
    :param keep: number of good models to keep; bad models will be deleted
    :param prefix: prefix to store model. default is "model"
    :param keeper_sort: criteria for choosing the old or bad models for deletion.
        Choices: {'total_score', 'step'}
    :return:
    """
    # TODO: improve this by skipping the model save if the model is not good enough to be saved
    if self.read_only:
        log.warning("Ignoring the store request; experiment is read-only")
        return
    name = f'{prefix}_{epoch:03d}_{train_score:.6f}_{val_score:.6f}.pkl'
    path = self.model_dir / name
    log.info(f"Saving epoch {epoch} to {path}")
    torch.save(model, str(path))

    del_models = []
    if keeper_sort == 'total_score':
        del_models = self.list_models(sort='total_score', desc=False)[keep:]
    elif keeper_sort == 'step':
        del_models = self.list_models(sort='step', desc=True)[keep:]
    else:
        raise Exception(f'Sort criteria {keeper_sort} not understood')
    for d_model in del_models:
        log.info(f"Deleting model {d_model}. Keep={keep}, sort={keeper_sort}")
        os.remove(str(d_model))

    with IO.writer(os.path.join(self.model_dir, 'scores.tsv'), append=True) as f:
        cols = [str(epoch), datetime.now().isoformat(), name,
                f'{train_score:g}', f'{val_score:g}']
        f.write('\t'.join(cols) + '\n')
def main():
    # No grads required
    torch.set_grad_enabled(False)
    args = parse_args()
    gen_args = {}
    exp = Experiment(args.pop('work_dir'), read_only=True)
    validate_args(args, exp)

    if exp.model_type == 'binmt':
        if not args.get('path'):
            raise Exception('--binmt-path argument is needed for BiNMT model.')
        gen_args['path'] = args.pop('binmt_path')

    weights = args.get('weights')
    if weights:
        decoder = Decoder.combo_new(exp, model_paths=args.pop('model_path'),
                                    weights=weights)
    else:
        decoder = Decoder.new(exp, gen_args=gen_args,
                              model_paths=args.pop('model_path', None),
                              ensemble=args.pop('ensemble', 1))

    if args.pop('interactive'):
        if weights:
            log.warning("Interactive shell not reloadable for combo mode. FIXME: TODO:")
        if args['input'] != sys.stdin or args['output'] != sys.stdout:
            log.warning('--input and --output args are not applicable in --interactive mode')
        args.pop('input')
        args.pop('output')

        while True:
            try:
                # a hacky way to unload and reload the model when the user tries to switch models
                decoder.decode_interactive(**args)
                break  # exit loop if there is no request for reload
            except ReloadEvent as re:
                decoder = Decoder.new(exp, gen_args=gen_args, model_paths=re.model_paths)
                args = re.state
                # go back to the loop and redo the interactive shell
    else:
        return decoder.decode_file(args.pop('input'), args.pop('output'), **args)
def train(self, steps: int, check_point: int, batch_size: int,
          check_pt_callback: Optional[Callable] = None, **args):
    train_state = TrainerState(self.model, check_point=check_point)
    train_state.train_mode(True)
    if self.start_step >= steps:
        log.warning(f"Already trained to {self.start_step}. Considering it as done.")
        return
    rem_steps = steps - self.start_step
    side = 'tgt'  # TODO: this should be inferrable or configurable instead of hardcoded
    train_data = self.exp.get_mono_data('train', side, batch_size=batch_size,
                                        batch_first=True, sort_dec=True,
                                        num_batches=rem_steps, shuffle=True)
    val_data = self.exp.get_mono_data('valid', side, batch_size=batch_size,
                                      batch_first=True, sort_dec=True)
    keep_models = 8
    unsaved_state = False
    with tqdm(train_data, initial=self.start_step, total=steps, unit='batch',
              dynamic_ncols=True) as data_bar:
        for batch in data_bar:
            batch.to(device)
            outp_log_probs = self.model.batch_forward(batch)
            loss = self.simple_loss_func(outp_log_probs, seq_lens=batch.x_len,
                                         tot_toks=batch.x_toks, max_seq_len=batch.max_x_len,
                                         train_mode=True)
            unsaved_state = True
            bar_msg, is_check_pt = train_state.step(batch.x_toks, loss)
            data_bar.set_postfix_str(bar_msg, refresh=True)
            del batch  # TODO: force free memory

            if is_check_pt:
                train_loss = train_state.reset()
                train_state.train_mode(False)
                val_loss = self.run_valid_epoch(val_data)
                self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
                if check_pt_callback:
                    check_pt_callback(model=self.model, step=self.opt.curr_step,
                                      train_loss=train_loss)
                train_state.train_mode(True)
                unsaved_state = False

    log.info("End of training session")
    if unsaved_state:
        # End of training
        train_loss = train_state.reset()
        train_state.train_mode(False)
        val_loss = self.run_valid_epoch(val_data)
        self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
def setup_jdbc():
    # add postgres jdbc driver
    import pyspark
    PG_JDBC_URL = 'https://jdbc.postgresql.org/download/postgresql-42.2.14.jar'
    jars_dir = Path(pyspark.__file__).parent / 'jars'
    if jars_dir.exists():
        pg_jdbc_jars = list(jars_dir.glob("postgresql-*.jar"))
        if pg_jdbc_jars:
            log.info(f'Located JDBC jar for postgres: {pg_jdbc_jars}')
        else:
            jar_path = jars_dir / (PG_JDBC_URL.split('/')[-1])
            download_file(PG_JDBC_URL, jar_path, fair_on_error=False)
    else:
        log.warning("pyspark jars are not detected. "
                    "You may need to manually add the postgres JDBC driver to the spark config")
def __init__(self, d_model, d_ff, dropout=0.1, activation='relu'):
    super().__init__()
    self.w_1 = nn.Linear(d_model, d_ff)
    self.w_2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)
    activations = dict(relu=F.relu, elu=F.elu)
    if activation == 'gelu':
        # TODO: when torch 1.2 comes out, simplify this block
        try:
            activations['gelu'] = F.gelu
        except:
            log.warning(f"gelu is not available. using torch {torch.__version__}."
                        f" GELU was added in https://github.com/pytorch/pytorch/pull/20665")
            raise
    self.activation = activations[activation]
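# Hedged usage sketch for the feed-forward block constructed above (referenced elsewhere
# in this listing as PositionwiseFeedForward): dimensions are typical transformer-base
# values chosen only for illustration, and the composed call below uses only the
# attributes defined in the constructor, mirroring FFN(x) = W2(dropout(act(W1(x)))).
def _example_positionwise_ff():
    ff = PositionwiseFeedForward(d_model=512, d_ff=2048, dropout=0.1, activation='relu')
    x = torch.zeros(2, 7, 512)  # [batch, seq, d_model]
    return ff.w_2(ff.dropout(ff.activation(ff.w_1(x))))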
def read_all(self) -> Iterator[IdExample]:
    with IO.reader(self.path) as lines:
        recs = (line.split('\t') for line in lines)
        for idx, rec in enumerate(recs):
            x = self._parse(rec[0].strip())
            y = self._parse(rec[1].strip()) if len(rec) > 1 else None
            if self.truncate:  # truncate long recs
                x = x[:self.max_src_len]
                y = y if y is None else y[:self.max_tgt_len]
            elif len(x) > self.max_src_len or (0 if y is None else len(y)) > self.max_tgt_len:
                continue  # skip long recs
            if not x or (y is not None and len(y) == 0):  # empty on one side
                log.warning(f"Ignoring an empty record x:{len(x)}"
                            f" y:{0 if y is None else len(y)}")
                continue
            yield IdExample(x, y, id=idx)
def _prep_file(file_key, out_file, do_truncate, max_len, field: Field):
    if file_key not in args:
        log.warning(f'Skipped: {file_key} because it is not found in config')
        return
    raw_file = args[file_key]

    recs = TSVData.read_raw_mono_recs(raw_file, do_truncate, max_len, field.encode_as_ids)
    # TODO: use SQLite storage
    TSVData.write_mono_recs(recs, out_file)
    if args.get('text_files'):
        recs = TSVData.read_raw_mono_recs(raw_file, do_truncate, max_len, field.tokenize)
        TSVData.write_mono_recs(recs, str(out_file).replace('.tsv', '.pieces.tsv'))
def make_check_point(self, train_loss: float, val_loss: float, keep_models: int):
    """
    Checkpoint the model
    :param train_loss: training loss value
    :param val_loss: loss on validation set
    :param keep_models: how many checkpoints to keep on the file system
    :return:
    """
    step_num = self.opt.curr_step
    if step_num == self.last_step:
        log.warning("Ignoring checkpt request")
        return  # calling multiple times doesn't save
    log.info(f"Checkpoint at step {step_num}. Training Loss {train_loss:g},"
             f" Validation Loss:{val_loss:g}")
    self.show_samples()

    self.tbd.add_scalars('losses', {'train_loss': train_loss,
                                    'valid_loss': val_loss}, step_num)

    # Unwrap model state from DataParallel and persist
    model = (self.model.module if isinstance(self.model, nn.DataParallel) else self.model)
    state = {
        'model_state': model.state_dict(),
        'optim_state': self.opt.optimizer.state_dict(),
        'step': step_num,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'time': time.time(),
        'rtg_version': rtg.__version__,
        'model_type': self.exp.model_type,
        'model_args': self.exp.model_args,
    }
    self.exp.store_model(step_num, state, train_score=train_loss,
                         val_score=val_loss, keep=keep_models)
    self.last_step = step_num
def main():
    args = parse_args()
    conf_file: Path = args.conf_file if args.conf_file else args.work_dir / 'conf.yml'
    assert conf_file.exists()

    ExpFactory = TranslationExperiment
    is_big = load_conf(conf_file).get('spark', {})
    if is_big:
        log.info("Big experiment mode enabled; checking pyspark backend")
        try:
            import pyspark
            log.info("pyspark is available")
        except:
            log.warning("unable to import pyspark. Please do 'pip install pyspark' and run again")
            raise
        from rtg.big.exp import BigTranslationExperiment
        ExpFactory = BigTranslationExperiment

    exp = ExpFactory(args.exp, config=conf_file, read_only=False)
    return exp.pre_process()
def from_lines(cls, lines: Iterator[str], batch_size: int, vocab: Field, sort=True):
    """
    Note: this changes the order based on sequence length if sort=True
    :param lines: stream of input lines
    :param batch_size: number of tokens in batch
    :param vocab: Field to use for mapping word pieces to ids
    :param sort: sort based on descending order of length
    :return: stream of DecoderBatches
    """
    log.info("Tokenizing sequences")
    buffer = []
    for i, line in enumerate(lines):
        line = line.strip()
        if not line:
            log.warning(f"line {i + 1} was empty. inserting a dot (.). "
                        f"Empty lines are problematic when you want line-by-line alignment...")
            line = "."
        cols = line.split('\t')
        seq = vocab.encode_as_ids(line, add_eos=True, add_bos=False)
        ref = cols[1] if len(cols) > 1 else None
        buffer.append((i, cols[0], ref, seq))  # idx, src, ref, seq

    if sort:
        log.info(f"Sorting based on the length. total = {len(buffer)}")
        buffer = sorted(buffer, reverse=True, key=lambda x: len(x[-1]))  # sort by length of seq

    batch = cls()
    for idx, src, ref, seq in buffer:
        batch.add(idx=idx, src=src, ref=ref, seq=seq)
        if batch.padded_tok_count >= batch_size:
            yield batch
            batch = cls()

    if batch.line_count > 0:
        yield batch