def main():
    """
    Entry point of the training command.

    Reads the command line arguments, optionally seeds the random number
    generators, restores the experiment from its work directory, builds the
    trainer matching the experiment's model type and runs the training.
    """
    args = parse_args()

    seed = args.pop("seed")
    if seed:  # a falsy seed leaves the random number generators untouched
        log.info(f"Seed for random number generator: {seed}")
        import random
        import torch
        random.seed(seed)
        torch.manual_seed(seed)

    exp = Experiment(args.pop('work_dir'))
    assert exp.has_prepared(), f'Experiment dir {exp.work_dir} is not ready to train. ' \
                               f'Please run "prep" sub task'

    _, optim_args = exp.optim_args
    if optim_args is None:
        optim_args = {}
    if args.get('optim_args'):
        # command line overrides arrive in key1=val1,key2=val2 format
        overrides = {}
        for item in args.pop('optim_args').split(','):
            item = item.strip()
            if item:
                name, value = item.split('=')
                overrides[name.strip()] = float(value)
        optim_args.update(overrides)

    trainers = {
        'TRANSFORMER': HieroTransformerTrainer,
        'HRED': SteppedHREDTrainer,
    }
    trainer_type = trainers[exp.model_type]
    trainer = trainer_type(exp, optim=args.pop('optim'), **optim_args)

    try:
        # whatever is left in args is handed to the trainer as is
        trainer.train(**args)
    except RuntimeError as e:
        # on an out-of-memory error, dump tensor allocations before failing
        if 'out of memory' in str(e).lower():
            log_tensor_sizes()
        raise e
def decode_dialogs(self, dialogs: Iterator[Dialog], out, verbose=True, **args):
    """
    Decodes every test chat of the given dialogs; optionally logs and writes the results.

    :param dialogs: dialogs to decode
    :param out: text stream that receives the TSV records; a falsy value (such as None)
        turns off all writing, including the dialog separator
    :param verbose: when true, logs the last context message, the reference response
        and the generated hypotheses of every chat
    :param args: extra keyword arguments handed to ``self.generate_chat``
    :return: None
    """
    min_ctx, max_ctx = self.exp.min_ctx, self.exp.max_ctx
    test_chars = None
    for i, dialog in enumerate(dialogs):
        if out:
            # write out the first min_ctx utterances as context records
            for utter in dialog.chat[:min_ctx]:
                line = "CTX\t"  # this is a context
                line += f"{utter.uid}\t" if utter.uid else ""
                line += f"{utter.raw_char}\t{utter.raw_text}\n"
                out.write(line)

        chats: Iterator[ChatRec] = dialog.as_test_chats(min_ctx=min_ctx, max_ctx=max_ctx,
                                                        test_chars=test_chars)
        for j, chat in enumerate(chats):
            # One chat in batch. Should/can be improved later
            batch = chat.as_dialog_mini_batch()
            result = self.generate_chat(batch, **args)
            if verbose:
                log.info(f"dialog: {i}: chat: {j} :: \n"
                         f"MSG: {chat.context[-1].raw_char}: {chat.context[-1].raw_text}\n"
                         f"RSP: {chat.response.raw_char}: {chat.response.raw_text}")
                out_line = '\n'.join(f'{hyp}\t{score:.4f}' for score, hyp in result)
                log.info(f"OUT:\n{out_line} \n")
            if out:
                resp = chat.response
                line = "GEN\t"  # Generated
                line += f"{resp.uid}\t" if resp.uid else ""
                line += f"{resp.raw_char}\t{resp.raw_text}\t"  # Reference Text
                line += "\t".join([f"{hyp}\t{score:g}" for score, hyp in result])
                line += "\n"
                out.write(line)

        if out:
            # dialog separator; guarded because callers may pass out=None
            out.write("\n")
def read_all(path: Path, add_eos):
    """
    Lazily reads dialogs from a TSV file in which blank lines separate dialogs.

    The last two columns of a record are the character id and the space separated
    token ids. With more than two columns the first one is the utterance id, and with
    more than three columns the second one is the utterance weight.

    :param path: file to read
    :param add_eos: append EOS_TOK_IDX to sequences that do not already end with it
    :return: generator of dialogs
    """
    n_dialogs = 0
    with IO.reader(path) as reader:
        current = Dialog()
        for raw in reader:
            raw = raw.strip()
            if not raw:
                # blank line closes the dialog in progress, if there is one
                if len(current) > 0:
                    yield current
                    n_dialogs += 1
                    current = Dialog()
                continue
            cols = raw.split("\t")
            char, seq = cols[-2:]  # the last two are mandatory
            uid = cols[0] if len(cols) > 2 else None
            weight = float(cols[1]) if len(cols) > 3 else None
            char = int(char)
            seq = [int(tok) for tok in seq.strip().split()]
            if add_eos and seq[-1] != EOS_TOK_IDX:
                seq.append(EOS_TOK_IDX)
            current.append(Utterance(char, seq, uid=uid, weight=weight))
        # the file may end without a trailing blank line
        if len(current) > 0:
            n_dialogs += 1
            yield current
    log.info(f"Read {n_dialogs} dialogs")
def write_tsv(records: Iterator[DialogRecord], path: Union[str, Path]):
    """
    Writes records as two-column TSV lines.

    :param records: pairs of (key, sequence); the sequence items are joined by a space
    :param path: where to store the lines
    :return: None
    """
    log.info(f"Storing data at {path}")
    with IO.writer(path) as f:
        for key, items in records:
            head = str(key)
            body = ' '.join(map(str, items))
            f.write(f'{head}\t{body}\n')
def store_model(self, step: int, model, train_score: float, val_score: float, keep: int):
    """
    Stores a model in the models directory, prunes the extra models and appends
    a record to scores.tsv. Nothing is done when the experiment is read-only.

    :param step: step number of training
    :param model: the object handed to ``torch.save``
    :param train_score: score of model on training split
    :param val_score: score of model on validation split
    :param keep: number of models to keep; models listed after the first ``keep``
        entries of the 'total_score' ordering are deleted
    :return: None
    """
    # TODO: improve this by skipping the model save if the model is not good enough to be saved
    if self.read_only:
        log.warning("Ignoring the store request; experiment is readonly")
        return

    name = f'model_{step:05d}_{train_score:.6f}_{val_score:.6f}.pkl'
    path = self.model_dir / name
    log.info(f"Saving... step={step} to {path}")
    torch.save(model, str(path))

    ranked = self.list_models(sort='total_score', desc=False)
    for bad_model in ranked[keep:]:
        log.info(f"Deleting bad model {bad_model} . Keep={keep}")
        os.remove(str(bad_model))

    scores_file = os.path.join(self.model_dir, 'scores.tsv')
    with IO.writer(scores_file, append=True) as f:
        cols = [str(step), datetime.now().isoformat(), name]
        cols.append(f'{train_score:g}')
        cols.append(f'{val_score:g}')
        f.write('\t'.join(cols) + '\n')
def __init__(self, exp: DialogExperiment, model: Optional[HieroTransformer] = None,
             optim: str = 'ADAM', **optim_args):
    """
    Sets up the trainer and its label-smoothed loss function.

    :param exp: experiment to train
    :param model: optional pre-initialized model
    :param optim: name of the optimizer
    :param optim_args: extra arguments passed on to the parent trainer
    :raises Exception: when more than one GPU is visible
    """
    super().__init__(exp, model, model_factory=HieroTransformer.make_model,
                     optim=optim, **optim_args)

    n_gpus = torch.cuda.device_count()
    device_ids = list(range(n_gpus))
    log.info(f"Going to use {n_gpus} GPU(s) ; ids:{device_ids}")
    if len(device_ids) > 1:
        # Multi GPU mode
        raise Exception("Multi GPU mode not supported yet")

    generator = self.model.generator
    criterion = LabelSmoothing(vocab_size=generator.vocab, padding_idx=PAD_IDX,
                               smoothing=self._smoothing)
    self.loss_func = SimpleLossFunction(generator, criterion, opt=self.opt)
def train(model_type: str, vocab_size: int, model_path: str, files: Iterator[str],
          no_split_toks: Optional[List[str]] = None):
    """
    Train Sentence Piece Model

    :param model_type: sentence piece model type: {unigram, BPE, word, char}
    :param vocab_size: target vocabulary size
    :param model_path: where to store model; a '.model' suffix is added when missing
    :param files: input files; duplicates are dropped, first occurrence order is kept
    :param no_split_toks: Don't split these tokens
    :return: Field backed by the trained model file
    """
    suffix = '.model'
    # Strip the suffix only. Replacing every '.model' occurrence would also mangle
    # a path such as 'my.models/text.model'.
    if model_path.endswith(suffix):
        model_prefix = model_path[:-len(suffix)]
    else:
        model_prefix = model_path
    # remove duplicates without making the order of inputs nondeterministic
    files = list(dict.fromkeys(files))

    arg = f"--input={','.join(files)} --vocab_size={vocab_size} --model_prefix={model_prefix}" \
          f" --model_type={model_type} --pad_id={PAD_TOK[1]} --bos_id={BOS_TOK[1]}" \
          f" --eos_id={EOS_TOK[1]} --unk_id={UNK_TOK[1]} --hard_vocab_limit=false"
    if no_split_toks:
        arg += f" --user_defined_symbols={','.join(no_split_toks)}"
    log.info(f"SPM: {arg}")
    SentencePieceTrainer.Train(arg)
    log.info("Training complete")

    if not model_path.endswith(suffix):
        model_path += suffix
    return Field(model_path)
def __init__(self, utter_encoder: Encoder, ctx_encoder: Encoder, decoder: Decoder,
             src_inp_embs: ComboEmbeddings, tgt_inp_embs: ComboEmbeddings,
             generator: Generator, dropout: float, sent_repr_mode: str = 'cls'):
    """
    Assembles the model from its already constructed components.

    :param utter_encoder: encoder stored as ``self.utter_encoder``
    :param ctx_encoder: encoder stored as ``self.ctx_encoder``
    :param decoder: decoder stored as ``self.decoder``
    :param src_inp_embs: embeddings stored as ``self.src_inp_embs``
    :param tgt_inp_embs: embeddings stored as ``self.tgt_inp_embs``
    :param generator: generator; its ``d_model`` becomes the model dimension
    :param dropout: dropout given to the sublayer connection and positional encoding
    :param sent_repr_mode: sentence representation mode, either 'sum' or 'cls'
    """
    super().__init__()
    # sub-modules are registered in this exact order
    self.utter_encoder = utter_encoder
    self.ctx_encoder = ctx_encoder
    self.decoder = decoder
    self.src_inp_embs: ComboEmbeddings = src_inp_embs
    self.tgt_inp_embs: ComboEmbeddings = tgt_inp_embs
    self.generator = generator
    self._model_dim = generator.d_model

    assert sent_repr_mode in ('sum', 'cls')
    log.info(f"Sentence Representation mode :: {sent_repr_mode}")
    self.sent_repr_mode = sent_repr_mode
    if sent_repr_mode == 'sum':
        log.warning("warning: summing the vectors didn't help in previous runs")
    self.sent_repr_conn = SublayerConnection(self._model_dim, dropout)

    # positional encoder for the chat sequence
    self.posit_enc = PositionalEncoding(self._model_dim, dropout=dropout)
def train(self, steps: int, check_point: int, check_pt_callback: Optional[Callable] = None,
          fine_tune=False, **args):
    """
    Runs the training loop until ``steps`` steps are reached.

    :param steps: total number of steps (batches) to train up to
    :param check_point: check point size given to the trainer state
    :param check_pt_callback: optional callable invoked after each check point with the
        keyword arguments ``model``, ``step`` and ``train_loss``
    :param fine_tune: passed on to the experiment when it creates the training data
    :param args: extra options; only 'keep_models' (default 4) is read here
    :raises Exception: when the model has already been trained for ``steps`` steps
    """
    # the loop below counts batches (steps), not epochs
    log.info(f'Going to train for {steps} steps; '
             f'check point size:{check_point}; fine_tune={fine_tune}')
    keep_models = args.get('keep_models', 4)  # keep last _ models and delete the old

    if steps <= self.start_step:
        raise Exception(f'The model was already trained to {self.start_step} steps. '
                        f'Please increase the steps or clear the existing models')

    train_data = self.exp.get_train_data(loop_steps=steps - self.start_step,
                                         fine_tune=fine_tune, sort_dec=False)
    val_data = self.exp.get_val_data(sort_dec=False)
    train_state = TrainerState(self.model, check_point=check_point)
    train_state.train_mode(True)

    with tqdm(train_data, initial=self.start_step, total=steps, unit='batch') as data_bar:
        for batch in data_bar:
            batch: DialogMiniBatch = batch  # type annotation
            self.model.zero_grad()
            out = self.model(batch)
            num_toks = batch.tot_resp_toks.float().item()
            loss = self.loss_func(out, batch.resp_seqs, num_toks, True)
            self.tbd.add_scalars('training', {'step_loss': loss,
                                              'learn_rate': self.opt.curr_lr},
                                 self.opt.curr_step)

            progress_msg, is_check_pt = train_state.step(num_toks, loss)
            progress_msg += f', LR={self.opt.curr_lr:g}'
            data_bar.set_postfix_str(progress_msg, refresh=False)
            del batch  # TODO: force free memory

            if is_check_pt:
                train_loss = train_state.reset()
                # leave training mode while the check point is made
                train_state.train_mode(False)
                self.make_check_point(val_data, train_loss, keep_models=keep_models)
                if check_pt_callback:
                    check_pt_callback(model=self.model, step=self.opt.curr_step,
                                      train_loss=train_loss)
                train_state.train_mode(True)
def new(self, parameters, lr=0.001, **args):
    """
    Creates an optimizer by calling ``self.value``.

    :param parameters: first positional argument of the optimizer
    :param lr: learning rate
    :param args: extra keyword arguments for the optimizer
    :return: whatever ``self.value(parameters, lr=lr, **args)`` returns
    """
    factory = self.value
    log.info(f"Creating {factory} optimizer with lr={lr} and extra args:{args}")
    defaults = inspect.signature(factory)
    log.info(f"   {factory}, default arguments {defaults}")
    return factory(parameters, lr=lr, **args)
def write_lines(path: Union[str, Path], lines):
    """
    Writes the given lines to a file, one per line, with surrounding whitespace stripped.

    :param path: where to write
    :param lines: iterable of strings
    :return: None
    """
    count = 0
    with IO.writer(path) as out:
        for count, line in enumerate(lines, start=1):
            out.write(line.strip() + "\n")
    log.info(f"Wrote {count} lines to {path}")
def pre_process_train_dev(self, args: Dict[str, Any]):
    """
    Prepares the vocabularies and the encoded training and validation files.

    :param args: preparation arguments. Keys read here: 'train_dialogs', 'valid_dialogs',
        'vocab_text', 'pieces', 'max_types', 'max_seq_len' and the optional
        'char_min_freq' (default 500), 'no_split_toks', 'finetune_dialogs' and
        'num_samples' (default 5)
    :return: None
    """
    # character names vocabulary
    if self.char_field and self._char_field_file.exists():
        log.warning("Skipping character vocab creating. since it already exists")
        self.char_field = LookupField(self._char_field_file)
    else:
        char_min_freq = args.get('char_min_freq', 500)
        log.info(f"Scanning characters in training data with with freq {char_min_freq}")
        char_names = self.scan_characters(args['train_dialogs'], min_freq=char_min_freq)
        log.info(f"Found {len(char_names)} characters")
        self.write_lines(self._char_field_file, char_names)
        self.char_field = LookupField(self._char_field_file)

    # Dialog Text vocabulary
    if self._text_field_file.exists() and self.text_field is not None:
        log.warning("Skipping the vocab creation since it already exist")
        self.text_field = Field(self._text_field_file)
    else:
        files = [args['vocab_text']]
        no_split_toks = args.get('no_split_toks')
        self.text_field = Field.train(args['pieces'], args['max_types'],
                                      str(self._text_field_file), files,
                                      no_split_toks=no_split_toks)

    # create Piece IDs
    for key, out_path, sample_wt in [('train_dialogs', self.train_file, True),
                                     ('valid_dialogs', self.valid_file, False)]:
        dialogs = RawDialogReader(args[key], text_field=self.text_field,
                                  char_field=self.char_field,
                                  max_seq_len=args['max_seq_len'])
        if sample_wt:
            dialogs = list(dialogs)  # if this causes OOM, re-read this file
            # generate weights for sampling
            weights = sampling_weights(cluster(dialogs).values())
            for dlg in dialogs:
                for utter in dlg.chat:
                    utter.weight = weights[utter.uid]
        self.write_dialogs(dialogs, out_path)

    # pre_process_finetune() requires the 'finetune_dialogs' key, so that is the key
    # that has to decide whether the fine tuning corpus gets prepared
    if args.get('finetune_dialogs'):
        self.pre_process_finetune(args)

    # get samples from validation set
    n_samples = args.get('num_samples', 5)
    samples = self.pick_samples(Path(args['valid_dialogs']), n_samples)
    self.write_dialogs(samples, self.samples_file)
def save(self, path):
    """
    Stores the message/response representations and texts with ``torch.save``.

    :param path: destination handed to ``torch.save``
    :return: None
    """
    log.info(f"Storing to {path}")
    # The reason for doing this crazy stuff is to increase the portability of models
    # if we simply dump object as pickle, then torch version must be matched during re-loading
    # So we dump only the state params and arrays
    state = {
        'msg_reprs': self.msg_reprs,
        'resp_reprs': self.resp_reprs,
        'resps': self.resps,
        'msgs': self.msgs,
    }
    torch.save(state, path)
def __init__(self, model_size, factor, warmup, optimizer, step=0):
    """
    Stores the schedule settings around the wrapped optimizer.

    :param model_size: stored as ``self.model_size``
    :param factor: stored as ``self.factor``
    :param warmup: stored as ``self.warmup``
    :param optimizer: the wrapped optimizer
    :param step: the step to start counting from
    """
    log_msg = f"model_size={model_size}, factor={factor}, warmup={warmup}, step={step}"
    self.optimizer = optimizer
    self.model_size = model_size
    self.factor = factor
    self.warmup = warmup
    self._step = step
    self._rate = 0
    log.info(log_msg)
def write_out(pairs, out):
    """
    Writes each record as one tab separated line.

    :param pairs: iterator of records; each record is an iterable of strings
    :param out: file stream to write output; its ``name`` is used in the log message
    :return: None
    """
    count = 0
    for count, rec in enumerate(pairs, start=1):
        line = "\t".join(rec)
        out.write(line + "\n")
    log.info(f"Wrote {count} recs to {out.name}")
def write_out(triples, out):
    """
    Writes each triple as one tab separated line holding the items of all three parts.

    :param triples: iterator of triples; each of the three parts is an iterable of strings
    :param out: file stream to write output; its ``name`` is used in the log message
    :return: None
    """
    count = 0
    for count, (c, m, r) in enumerate(triples, start=1):
        fields = [*c, *m, *r]
        out.write("\t".join(fields) + "\n")
    log.info(f"Wrote {count} recs to {out.name}")
def __iter__(self):
    """
    Yields the dialogs of ``self.path``.

    Without shuffling the dialogs are streamed from the file. With shuffling all
    dialogs are held in ``self._mem`` (filled on first use) and reshuffled in place
    on every iteration.
    """
    if not self.shuffle:
        yield from self.read_all(self.path, add_eos=self.add_eos)
        return

    if not self._mem:
        log.info("Going to shuffle using a buffer. If this causes OOM, don't blame me!")
        self._mem = list(self.read_all(self.path, add_eos=self.add_eos))
    random.shuffle(self._mem)
    yield from self._mem
def write_dialogs(dialogs: Iterator[Dialog], out: Path, dialog_sep='\n'):
    """
    Writes dialogs as TSV records, one utterance per line.

    Each line holds the character and the space joined text, preceded by the
    utterance id and the weight whenever those are truthy.

    :param dialogs: dialogs to write
    :param out: path of the output file
    :param dialog_sep: text written after the utterances of every dialog
    :return: None
    """
    count = 0
    with IO.writer(out) as outh:
        for dialog in dialogs:
            count += 1
            for utter in dialog.chat:
                cols = []
                if utter.uid:
                    cols.append(f'{utter.uid}')
                if utter.weight:
                    cols.append(f'{utter.weight:g}')
                cols.append(f'{utter.char}')
                cols.append(" ".join(map(str, utter.text)))
                outh.write('\t'.join(cols) + '\n')
            outh.write(dialog_sep)
    log.info(f"Wrote {count} recs to {out}")
def __init__(self, d_model, text_vocab, char_vocab, char_emb_size=None):
    """
    Creates the text embeddings and, optionally, the character embeddings.

    :param d_model: dimension of the text embeddings and of the merged output
    :param text_vocab: number of entries in the text embedding table
    :param char_vocab: number of entries in the character embedding table
    :param char_emb_size: dimension of the character embeddings; None, zero or a
        negative value disables them
    """
    super().__init__()
    self.text_emb = nn.Embedding(text_vocab, d_model)
    self.char_emb_size = char_emb_size
    # The default None must not be compared with an int: `None > 0` raises TypeError
    if char_emb_size is not None and char_emb_size > 0:
        log.info(f"Character embeddings enabled: dim={char_emb_size}")
        self.char_emb = nn.Embedding(char_vocab, self.char_emb_size)
        # linear layer sized for the character and text embeddings put together
        self.merge = nn.Linear(self.char_emb_size + d_model, d_model)
    else:
        log.info("Character embeddings disabled")
        self.char_emb = None
        self.merge = None
    self.d_model = d_model
def __iter__(self) -> Iterator[Dialog]:
    """
    Reads raw dialogs from ``self.reader``; blank lines separate the dialogs.

    The last two columns of a line are the character and the text; with more than
    two columns the first one is the utterance id. Lines having fewer columns than
    the widest line seen so far are logged and skipped. Character and text are
    encoded when the corresponding field is set. Sequences longer than
    ``self.max_seq_len`` are truncated and terminated with the EOS token.
    """
    count = 0
    dialog = Dialog()
    for line in self.reader:
        line = line.strip()
        if not line:
            if len(dialog) > 0:
                yield dialog
                count += 1
                dialog = Dialog()
            continue

        parts = line.split("\t")
        if len(parts) < self.num_cols:
            log.error(f"Skipping the line: {line}")
            continue
        self.num_cols = max(len(parts), self.num_cols)

        raw_char, raw_text = parts[-2:]
        uid = parts[0] if len(parts) > 2 else None
        char = raw_char = raw_char.strip()
        if self.char_field:
            char = self.char_field.encode_as_id(raw_char)
        if self.text_field:
            seq = self.text_field.encode_as_ids(raw_text, add_eos=True)
        else:
            seq = raw_text.strip().split()
            if self.add_eos and seq[-1] != EOS_TOK[0]:
                seq.append(EOS_TOK[0])
        if len(seq) > self.max_seq_len:
            # leave room for the EOS token that closes the truncated sequence
            seq = seq[:self.max_seq_len - 1]
            seq.append(EOS_TOK_IDX if self.text_field else EOS_TOK[0])

        utter = Utterance(char, seq, raw_text=raw_text, raw_char=raw_char, uid=uid)
        dialog.append(utter)

    if len(dialog) > 0:
        count += 1
        yield dialog
    log.info(f"Read {count} dialogs")
    try:
        self.reader.close()
    except Exception:
        # best effort: ignore a reader that cannot be closed, but do not
        # swallow KeyboardInterrupt or SystemExit as a bare except would
        pass
def show_samples(self, beam_size=5, num_hyp=5, max_len=30, skip_top=0):
    """
    Decodes the experiment's samples with the current model so that the output is logged.
    Only a message is logged when there are no samples.

    :param beam_size: beam size
    :param num_hyp: number of hypothesis to output
    :param max_len: maximum length to decode
    :param skip_top: number of top beams to skip (to improve diversity)
    :return: None
    """
    if self.samples:
        decode_args = dict(beam_size=beam_size, num_hyp=num_hyp, max_len=max_len,
                           skip_top=skip_top)
        # out=None: nothing is written to a file
        self.decoder.decode_dialogs(self.samples, out=None, **decode_args)
    else:
        log.info("No samples are chosen by the experiment")
def pre_process_finetune(self, args=None):
    """
    Encodes the fine tuning dialogs, assigns sampling weights to their utterances
    and writes them to ``self.finetune_file``.

    :param args: preparation arguments; the 'prep' section of the experiment config is
        used when this is empty. Must have 'finetune_dialogs' and 'max_seq_len'
    :return: None
    """
    log.info("Going to prep fine tune files")
    if not args:
        args = self.config['prep']
    assert 'finetune_dialogs' in args

    reader = RawDialogReader(args['finetune_dialogs'], text_field=self.text_field,
                             char_field=self.char_field, max_seq_len=args['max_seq_len'])
    dialogs = list(reader)
    weights = sampling_weights(cluster(dialogs).values())
    for dlg in dialogs:
        for utter in dlg.chat:
            utter.weight = weights[utter.uid]
    self.write_dialogs(dialogs, self.finetune_file)
def __iter__(self):
    """
    Groups the mini chats of all dialogs into batches.

    A batch is emitted as soon as the buffered utterances reach ``self.max_utters`` or
    the buffered chats reach ``self.max_chats``; what is left at the end goes into a
    final batch. The number of batches is logged whenever it differs from the
    previous iteration.
    """
    n_batches = 0
    utters: OrderedSet = OrderedSet()
    chats: List[ChatRec] = list()

    def make_batch():
        # builds a batch from whatever is buffered right now
        return DialogMiniBatchRaw.new(utters.to_list(), chats=chats,
                                      sort_desc=self.sort_desc, pad=self.pad)

    for dialog in self.reader:
        mini_chats = dialog.as_mini_chats(min_ctx=self.min_ctx, max_ctx=self.max_ctx,
                                          model_chars=self.model_chars,
                                          min_resp_len=self.min_resp_len,
                                          no_repeat=self.no_repeat,
                                          down_sample=self.down_sample)
        for chat in mini_chats:
            # this might exceed max_utters, but that's okay
            utters.maybe_update(chat.context)
            utters.maybe_add(chat.response)
            chats.append(chat)
            if len(utters) >= self.max_utters or len(chats) >= self.max_chats:
                yield make_batch()
                n_batches += 1
                utters.clear()
                chats.clear()

    if chats:
        # left over in the buffer
        yield make_batch()
        n_batches += 1

    if n_batches != self.last_count:
        log.info(f"Produced {n_batches} dialog batches")
        self.last_count = n_batches
def __test_seq2seq_model__():
    """
    Manual smoke test: trains a small HRED model on a read-only experiment
    located at a hard-coded work directory.
    """
    work_dir = '/Users/tg/work/phd/cs644/project/virtchar/tmp.work'
    exp = Experiment(work_dir, read_only=True)
    text_vocab, char_vocab = len(exp.text_field), len(exp.char_field)
    log.info(f"====== VOCAB={text_vocab}, Characters:{char_vocab}======")

    # model sizes
    emb_size, char_emb_size, model_dim = 100, 50, 100
    # training schedule
    steps, step_size, check_pt = 2000, 50, 10

    model, _ = HRED.make_model(text_vocab=text_vocab, char_vocab=char_vocab,
                               text_emb_size=emb_size, char_emb_size=char_emb_size,
                               hid_size=model_dim, n_layers=1)
    trainer = SteppedHREDTrainer(exp=exp, model=model, lr=0.01, warmup_steps=500)
    trainer.train(steps=steps, step_size=step_size, check_point=check_pt)
def log_tensor_sizes(writer=log.info, min_size=1024):
    """
    Forces garbage collector and logs all the current tensors

    :param writer: callable that receives the table header and one line per tensor
    :param min_size: only tensors whose approximate size in bytes is bigger than this
        are reported
    :return: None
    """
    log.info("Collecting tensor allocations")
    gc.collect()

    def is_tensor(obj):
        if torch.is_tensor(obj):
            return True
        try:
            # some native objects raise exceptions
            return hasattr(obj, 'data') and torch.is_tensor(obj.data)
        except Exception:
            # narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate
            return False

    tensors = filter(is_tensor, gc.get_objects())
    # (number of items, type name, shape, object id)
    stats = ((reduce(op.mul, obj.size()) if len(obj.size()) > 0 else 0,
              obj.type(), tuple(obj.size()), hex(id(obj)))
             for obj in tensors)
    # prepend the approximate bytes: number of items times the size of the type
    stats = ((n * tensor_size[typ], n, typ, *blah) for n, typ, *blah in stats)
    stats = (x for x in stats if x[0] > min_size)
    sorted_stats = sorted(stats, key=lambda x: x[0])

    writer("####\tApprox Bytes\tItems      \tShape   \tObject ID")
    lines = (f'{i:4}\t{size:12,}\t{n:12,}\t{typ}\t{shape}\t{_id}'
             for i, (size, n, typ, shape, _id) in enumerate(sorted_stats))
    log.info("==== Tensors and memories === ")
    for line in lines:
        writer(line)

    total = sum(rec[0] for rec in sorted_stats)
    log.info(f'Total Bytes by tensors bigger than {min_size} is (approx):{total:,}')
def __init__(self, work_dir: Union[str, Path], read_only=False,
             config: Optional[Dict[str, Any]] = None):
    """
    Sets up the file layout of an experiment and loads what already exists in it.

    :param work_dir: directory of the experiment
    :param read_only: when true, no directories are created
    :param config: configuration dictionary, or the path (str or Path) of a file to load
        it from; when empty, conf.yml of the work directory is loaded
    """
    work_dir = Path(work_dir)  # accepts both str and Path
    log.info(f"Initializing an experiment. Directory = {work_dir}")
    self.read_only = read_only
    self.work_dir = work_dir
    self.data_dir = work_dir / 'data'
    self.model_dir = work_dir / 'models'
    self._config_file = work_dir / 'conf.yml'
    self._text_field_file = self.data_dir / 'text.model'
    self._char_field_file = self.data_dir / 'vocab.char.txt'
    self._prepared_flag = self.work_dir / '_PREPARED'
    self._trained_flag = self.work_dir / '_TRAINED'

    self.train_file = self.data_dir / 'train.tsv.gz'
    self.finetune_file = self.data_dir / 'finetune.tsv.gz'
    self.valid_file = self.data_dir / 'valid.tsv.gz'
    # a set of samples to watch the progress qualitatively
    self.samples_file = self.data_dir / 'samples.tsv.gz'

    if not read_only:
        for _dir in [self.model_dir, self.data_dir]:
            # exist_ok avoids the race between an exists() check and mkdir()
            _dir.mkdir(parents=True, exist_ok=True)

    # a path given as config is loaded; a Path used to be stored as the config itself
    if isinstance(config, (str, Path)):
        config = load_conf(config)
    self.config = config if config else load_conf(self._config_file)

    self.text_field = Field(str(self._text_field_file)) \
        if self._text_field_file.exists() else None
    self.char_field = LookupField(str(self._char_field_file)) \
        if self._char_field_file.exists() else None

    # these are the characters to which we optimize the loss
    self._model_chars = None
def train(self, steps: int, check_point: int, fine_tune=False,
          check_pt_callback: Optional[Callable] = None, **args):
    """
    Runs the training loop until ``steps`` steps are reached.

    :param steps: total number of steps (batches) to train up to
    :param check_point: check point size given to the trainer state
    :param fine_tune: passed on to the experiment when it creates the training data
    :param check_pt_callback: optional callable invoked after each check point with the
        keyword arguments ``model``, ``step`` and ``train_loss``
    :param args: extra options; only 'keep_models' (default 4) is read here
    :raises Exception: when the model has already been trained for ``steps`` steps
    """
    log.info(f'Going to train for {steps} steps; '
             f'check point size:{check_point}; fine tune={fine_tune}')
    keep_models = args.get('keep_models', 4)  # keep last _ models and delete the old

    if steps <= self.start_step:
        raise Exception(f'The model was already trained to {self.start_step} steps. '
                        f'Please increase the steps or clear the existing models')

    # fine_tune has to reach the experiment, otherwise the flag is only logged
    train_data = self.exp.get_train_data(loop_steps=steps - self.start_step,
                                         fine_tune=fine_tune)
    val_data = self.exp.get_val_data()
    train_state = TrainerState(self.model, check_point=check_point)
    train_state.train_mode(True)

    with tqdm(train_data, initial=self.start_step, total=steps, unit='batch') as data_bar:
        for batch in data_bar:
            # Step clear gradients
            self.model.zero_grad()
            # Step Run forward pass.
            outp_log_probs = self.model(batch)
            loss = self.loss_func(outp_log_probs, batch, True)
            self.tbd.add_scalars('training', {'step_loss': loss,
                                              'learn_rate': self.opt.curr_lr},
                                 self.opt.curr_step)

            bar_msg, is_check_pt = train_state.step(batch.tot_resp_toks.item(), loss)
            bar_msg += f', LR={self.opt.curr_lr:g}'
            data_bar.set_postfix_str(bar_msg, refresh=False)
            del batch  # TODO: force free memory

            if is_check_pt:
                train_loss = train_state.reset()
                # leave training mode while the check point is made
                train_state.train_mode(False)
                self.make_check_point(val_data, train_loss, keep_models=keep_models)
                if check_pt_callback:
                    check_pt_callback(model=self.model, step=self.opt.curr_step,
                                      train_loss=train_loss)
                train_state.train_mode(True)
def get_train_data(self, shuffle=False, fine_tune=False, loop_steps=0, sort_dec=True) \
        -> Iterator[DialogMiniBatch]:
    """
    Creates the batch reader over the training corpus, or over the fine tuning corpus.

    :param shuffle: not supported at the moment; must be false
    :param fine_tune: read the fine tuning file (prepared here when it does not exist)
        instead of the training file
    :param loop_steps: when positive, the batches are wrapped in a LoopingIterable
        with this total
    :param sort_dec: passed to the batch reader as ``sort_desc``
    :return: iterable of batches
    """
    assert not shuffle, 'Not supported at the moment'

    if fine_tune:
        if not self.finetune_file.exists():
            # user may have added fine tune file later
            self.pre_process_finetune()
        log.info("Using Fine tuning corpus instead of training corpus")
        inp_file = self.finetune_file
    else:
        inp_file = self.train_file

    reader = DialogReader(inp_file)
    train_data = DialogBatchReader(reader, min_ctx=self.min_ctx, max_ctx=self.max_ctx,
                                   max_chats=self.max_chats, max_utters=self.max_utters,
                                   model_chars=None, min_resp_len=self.min_resp_len,
                                   no_repeat=self.no_repeat, sort_desc=sort_dec)
    if loop_steps > 0:
        return LoopingIterable(train_data, total=loop_steps)
    return train_data
def make_check_point(self, val_data: Iterator[DialogMiniBatch], train_loss: float,
                     keep_models: int):
    """
    Validates the model, reports the losses and stores the model and optimizer states.

    :param val_data: validation data to obtain validation score
    :param train_loss: training loss value
    :param keep_models: how many checkpoints to keep on file system
    :return: None
    """
    step_num = self.opt.curr_step
    val_loss = self.run_valid_epoch(val_data)
    log.info(f"Checkpoint at step {step_num}. Training Loss {train_loss:g},"
             f" Validation Loss:{val_loss:g}")
    self.show_samples()

    losses = {'train_loss': train_loss, 'valid_loss': val_loss}
    self.tbd.add_scalars('losses', losses, step_num)

    # Unwrap model state from DataParallel and persist
    model = self.model
    if isinstance(model, nn.DataParallel):
        model = model.module
    state = dict(model_state=model.state_dict(),
                 optim_state=self.opt.optimizer.state_dict(),
                 step=step_num,
                 train_loss=train_loss,
                 val_loss=val_loss,
                 time=time.time(),
                 rtg_version=virtchar.__version__)
    self.exp.store_model(step_num, state, train_score=train_loss, val_score=val_loss,
                         keep=keep_models)
def new(cls, exp: DialogExperiment, model=None, gen_args=None,
        model_paths: Optional[List[str]] = None, ensemble: int = 1):
    """
    create a new decoder

    :param exp: experiment
    :param model: Optional pre initialized model; when missing, a model is built from
        the experiment's model arguments and its state is restored
    :param gen_args: any optional args needed for generator
    :param model_paths: optional model paths
    :param ensemble: number of models to use for ensembling (if model is not specified)
    :return: the decoder
    """
    if model is not None:
        # a pre-initialized model is used as is, minus the DataParallel wrapper
        if isinstance(model, nn.DataParallel):
            model = model.module
    else:
        make_model = factories[exp.model_type]
        model = make_model(**exp.model_args)[0]
        state = cls.maybe_ensemble_state(exp, model_paths=model_paths, ensemble=ensemble)
        model.load_state_dict(state)
        log.info("Successfully restored the model state.")

    model = model.eval().to(device=device)
    generator = generators[exp.model_type]
    return cls(model, generator, exp, gen_args)