Example 1
def main():
    args = parse_args()
    seed = args.pop("seed")
    if seed:
        log.info(f"Seed for random number generator: {seed}")
        import random
        import torch
        random.seed(seed)
        torch.manual_seed(seed)

    exp = Experiment(args.pop('work_dir'))
    assert exp.has_prepared(), f'Experiment dir {exp.work_dir} is not ready to train. ' \
                               f'Please run the "prep" sub-task'
    _, optim_args = exp.optim_args
    if optim_args is None:
        optim_args = {}
    if args.get('optim_args'):
        # convert key1=val1,key2=val2 format to dictionary
        pairs = [x.strip() for x in args.pop('optim_args').split(',')]
        pairs = [pair.split('=') for pair in pairs if pair]
        optim_args.update({k.strip(): float(v) for k, v in pairs})

    trainer = {
        'TRANSFORMER': HieroTransformerTrainer,
        'HRED': SteppedHREDTrainer,
    }[exp.model_type](exp, optim=args.pop('optim'), **optim_args)
    try:
        trainer.train(**args)
    except RuntimeError as e:
        if 'out of memory' in str(e).lower():
            log_tensor_sizes()
        raise  # bare raise preserves the original traceback
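
For reference, the key1=val1,key2=val2 conversion above coerces every value to float; a hypothetical input behaves like this:

    pairs = [x.strip() for x in "lr=0.0005,warmup_steps=4000".split(',')]
    pairs = [pair.split('=') for pair in pairs if pair]
    print({k.strip(): float(v) for k, v in pairs})
    # -> {'lr': 0.0005, 'warmup_steps': 4000.0}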
Example 2
    def decode_dialogs(self, dialogs: Iterator[Dialog], out, verbose=True, **args):
        min_ctx, max_ctx = self.exp.min_ctx, self.exp.max_ctx
        test_chars = None
        for i, dialog in enumerate(dialogs):
            if out:
                # write out the context
                for utter in dialog.chat[:min_ctx]:
                    line = "CTX\t"  # this is a context
                    line += f"{utter.uid}\t" if utter.uid else ""
                    line += f"{utter.raw_char}\t{utter.raw_text}\n"
                    out.write(line)

            chats: Iterator[ChatRec] = dialog.as_test_chats(min_ctx=min_ctx, max_ctx=max_ctx,
                                                            test_chars=test_chars)
            for j, chat in enumerate(chats):
                # One chat in batch. Should/can be improved later
                batch = chat.as_dialog_mini_batch()
                result = self.generate_chat(batch, **args)

                if verbose:
                    log.info(f"dialog: {i}: chat: {j} :: \n"
                             f"MSG: {chat.context[-1].raw_char}: {chat.context[-1].raw_text}\n"
                             f"RSP: {chat.response.raw_char}: {chat.response.raw_text}")
                    out_line = '\n'.join(f'{hyp}\t{score:.4f}' for score, hyp in result)
                    log.info(f"OUT:\n{out_line} \n")
                if out:
                    resp = chat.response
                    line = f"GEN\t"     # Generated
                    line += f"{resp.uid}\t" if resp.uid else ""
                    line += f"{resp.raw_char}\t{resp.raw_text}\t"  # Reference Text
                    line += "\t".join([f"{hyp}\t{score:g}" for score, hyp in result])
                    line += "\n"
                    out.write(line)

            out.write("\n")      # dialog seperator
Example 3
 def read_all(path: Path, add_eos):
     count = 0
     with IO.reader(path) as reader:
         dialog = Dialog()
         for line in reader:
             line = line.strip()
             if line:
                 parts = line.split("\t")
                 char, seq = parts[-2:]  # the last two are mandatory
                 uid = parts[0] if len(parts) > 2 else None
                 weight = float(parts[1]) if len(parts) > 3 else None
                 char, seq = int(char), [
                     int(x) for x in seq.strip().split()
                 ]
                 if add_eos and seq[-1] != EOS_TOK_IDX:
                     seq.append(EOS_TOK_IDX)
                 dialog.append(Utterance(char, seq, uid=uid, weight=weight))
             else:
                 if len(dialog) > 0:
                     yield dialog
                     count += 1
                     dialog = Dialog()
         if len(dialog) > 0:
             count += 1
             yield dialog
     log.info(f"Read {count} dialogs")
Example 4
 def write_tsv(records: Iterator[DialogRecord], path: Union[str, Path]):
     seqs = ((str(x), ' '.join(map(str, y))) for x, y in records)
     lines = (f'{x}\t{y}\n' for x, y in seqs)
     log.info(f"Storing data at {path}")
     with IO.writer(path) as f:
         for line in lines:
             f.write(line)
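
A hypothetical invocation, assuming records of (id, token-id sequence) pairs and an illustrative path:

    records = iter([(1, [45, 12, 2]), (2, [7, 33, 2])])
    write_tsv(records, 'data/train.ids.tsv')  # writes lines like "1\t45 12 2"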
Example 5
    def store_model(self, step: int, model, train_score: float,
                    val_score: float, keep: int):
        """
        saves model to a given path
        :param step: step number of training
        :param model: model object itself
        :param train_score: score of model on training split
        :param val_score: score of model on validation split
        :param keep: number of good models to keep, bad models will be deleted
        :return:
        """
        # TODO: improve this by skipping the model save if the model is not good enough to be saved
        if self.read_only:
            log.warning("Ignoring the store request; experiment is readonly")
            return
        name = f'model_{step:05d}_{train_score:.6f}_{val_score:.6f}.pkl'
        path = self.model_dir / name
        log.info(f"Saving... step={step} to {path}")
        torch.save(model, str(path))

        for bad_model in self.list_models(sort='total_score',
                                          desc=False)[keep:]:
            log.info(f"Deleting bad model {bad_model} . Keep={keep}")
            os.remove(str(bad_model))

        with IO.writer(os.path.join(self.model_dir, 'scores.tsv'),
                       append=True) as f:
            cols = [
                str(step),
                datetime.now().isoformat(), name, f'{train_score:g}',
                f'{val_score:g}'
            ]
            f.write('\t'.join(cols) + '\n')
Example 6
    def __init__(self,
                 exp: DialogExperiment,
                 model: Optional[HieroTransformer] = None,
                 optim: str = 'ADAM',
                 **optim_args):
        super().__init__(exp,
                         model,
                         model_factory=HieroTransformer.make_model,
                         optim=optim,
                         **optim_args)

        device_ids = list(range(torch.cuda.device_count()))
        log.info(
            f"Going to use {torch.cuda.device_count()} GPU(s) ; ids:{device_ids}"
        )

        if len(device_ids) > 1:  # Multi GPU mode
            raise Exception("Multi GPU mode not supported yet")
        generator = self.model.generator

        criterion = LabelSmoothing(vocab_size=generator.vocab,
                                   padding_idx=PAD_IDX,
                                   smoothing=self._smoothing)

        self.loss_func = SimpleLossFunction(generator, criterion, opt=self.opt)
Example 7
 def train(model_type: str,
           vocab_size: int,
           model_path: str,
           files: Iterator[str],
           no_split_toks: Optional[List[str]] = None):
     """
     Train Sentence Piece Model
     :param model_type: sentence piece model type: {unigram, BPE, word, char}
     :param vocab_size: target vocabulary size
     :param model_path: where to store model
     :param files: input files
     :param no_split_toks: Don't split these tokens
     :return:
     """
     model_prefix = model_path.replace('.model', '')
     files = set(files)  # remove duplicates
     arg = f"--input={','.join(files)} --vocab_size={vocab_size} --model_prefix={model_prefix}" \
           f" --model_type={model_type} --pad_id={PAD_TOK[1]} --bos_id={BOS_TOK[1]}" \
           f" --eos_id={EOS_TOK[1]} --unk_id={UNK_TOK[1]} --hard_vocab_limit=false"
     if no_split_toks:
         arg += f" --user_defined_symbols={','.join(no_split_toks)}"
     log.info(f"SPM: {arg}")
     SentencePieceTrainer.Train(arg)
     log.info("Training complete")
     if not model_path.endswith('.model'):
         model_path += '.model'
     return Field(model_path)
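
A hypothetical call; the paths, vocabulary size, and special tokens here are illustrative only:

    field = train(model_type='bpe', vocab_size=8000,
                  model_path='data/text.model',
                  files=['corpus/part1.txt', 'corpus/part2.txt'],
                  no_split_toks=['<url>', '<number>'])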
Example 8
 def __init__(self,
              utter_encoder: Encoder,
              ctx_encoder: Encoder,
              decoder: Decoder,
              src_inp_embs: ComboEmbeddings,
              tgt_inp_embs: ComboEmbeddings,
              generator: Generator,
              dropout: float,
              sent_repr_mode: str = 'cls'):
     super().__init__()
     self.utter_encoder = utter_encoder
     self.ctx_encoder = ctx_encoder
     self.decoder = decoder
     self.src_inp_embs: ComboEmbeddings = src_inp_embs
     self.tgt_inp_embs: ComboEmbeddings = tgt_inp_embs
     self.generator = generator
     self._model_dim = generator.d_model
     assert sent_repr_mode in ('sum', 'cls')
     log.info(f"Sentence Representation mode :: {sent_repr_mode}")
     self.sent_repr_mode = sent_repr_mode
     if sent_repr_mode == 'sum':
         log.warning("summing the vectors didn't help in previous runs")
         self.sent_repr_conn = SublayerConnection(self._model_dim, dropout)
     # positional encoder for the chat sequence
     self.posit_enc = PositionalEncoding(self._model_dim, dropout=dropout)
Example 9
    def train(self,
              steps: int,
              check_point: int,
              check_pt_callback: Optional[Callable] = None,
              fine_tune=False,
              **args):
        log.info(f'Going to train for {steps} steps; '
                 f'check point size:{check_point}; fine_tune={fine_tune}')
        # keep the best `keep_models` checkpoints and delete the rest
        keep_models = args.get('keep_models', 4)

        if steps <= self.start_step:
            raise Exception(
                f'The model was already trained to {self.start_step} steps. '
                f'Please increase the steps or clear the existing models')
        train_data = self.exp.get_train_data(loop_steps=steps - self.start_step,
                                             fine_tune=fine_tune,
                                             sort_dec=False)
        val_data = self.exp.get_val_data(sort_dec=False)

        train_state = TrainerState(self.model, check_point=check_point)
        train_state.train_mode(True)

        with tqdm(train_data,
                  initial=self.start_step,
                  total=steps,
                  unit='batch') as data_bar:
            for batch in data_bar:
                batch: DialogMiniBatch = batch  # type annotation
                self.model.zero_grad()
                out = self.model(batch)
                num_toks = batch.tot_resp_toks.float().item()
                loss = self.loss_func(out, batch.resp_seqs, num_toks, True)
                self.tbd.add_scalars('training', {
                    'step_loss': loss,
                    'learn_rate': self.opt.curr_lr
                }, self.opt.curr_step)

                progress_msg, is_check_pt = train_state.step(num_toks, loss)
                progress_msg += f', LR={self.opt.curr_lr:g}'

                data_bar.set_postfix_str(progress_msg, refresh=False)
                del batch  # TODO: force free memory

                if is_check_pt:
                    train_loss = train_state.reset()
                    train_state.train_mode(False)
                    self.make_check_point(val_data,
                                          train_loss,
                                          keep_models=keep_models)
                    if check_pt_callback:
                        check_pt_callback(model=self.model,
                                          step=self.opt.curr_step,
                                          train_loss=train_loss)
                    train_state.train_mode(True)
Example 10
 def new(self, parameters, lr=0.001, **args):
     log.info(
         f"Creating {self.value} optimizer with lr={lr} and extra args:{args}"
     )
     log.info(
         f"   {self.value}, default arguments {inspect.signature(self.value)}"
     )
     return self.value(parameters, lr=lr, **args)
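
new() reads like a method on an Enum whose value is an optimizer class; a minimal sketch of such an enum (the exact members are an assumption):

    from enum import Enum
    import torch

    class Optims(Enum):
        ADAM = torch.optim.Adam
        SGD = torch.optim.SGD
        # ... with new() defined as above ...

    # usage: opt = Optims.ADAM.new(model.parameters(), lr=0.001)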
Example 11
 def write_lines(path: Union[str, Path], lines):
     count = 0
     with IO.writer(path) as out:
         for line in lines:
             count += 1
             out.write(line.strip())
             out.write("\n")
         log.info(f"Wrote {count} lines to {path}")
Example 12
    def pre_process_train_dev(self, args: Dict[str, Any]):

        # character names vocabulary
        if self.char_field and self._char_field_file.exists():
            log.warning(
                "Skipping character vocab creation since it already exists")
            self.char_field = LookupField(self._char_field_file)
        else:
            char_min_freq = args.get('char_min_freq', 500)
            log.info(
                f"Scanning characters in training data with min_freq={char_min_freq}")
            char_names = self.scan_characters(args['train_dialogs'],
                                              min_freq=char_min_freq)
            log.info(f"Found {len(char_names)} characters")
            self.write_lines(self._char_field_file, char_names)
            self.char_field = LookupField(self._char_field_file)

        # Dialog Text vocabulary
        if self._text_field_file.exists() and self.text_field is not None:
            log.warning("Skipping the vocab creation since it already exist")
            self.text_field = Field(self._text_field_file)
        else:
            files = [args['vocab_text']]
            no_split_toks = args.get('no_split_toks')
            self.text_field = Field.train(args['pieces'],
                                          args['max_types'],
                                          str(self._text_field_file),
                                          files,
                                          no_split_toks=no_split_toks)

        # create Piece IDs
        for key, out_path, sample_wt in \
                [('train_dialogs', self.train_file, True),
                 ('valid_dialogs', self.valid_file, False)]:
            dialogs = RawDialogReader(args[key],
                                      text_field=self.text_field,
                                      char_field=self.char_field,
                                      max_seq_len=args['max_seq_len'])
            if sample_wt:
                # buffer in memory; if this causes OOM, re-read the file instead
                dialogs = list(dialogs)
                # generate weights for sampling
                weights = sampling_weights(cluster(dialogs).values())
                for dlg in dialogs:
                    for utter in dlg.chat:
                        utter.weight = weights[utter.uid]

            self.write_dialogs(dialogs, out_path)

        if args.get("finetune_src") or args.get("finetune_tgt"):
            self.pre_process_finetune(args)

        # get samples from validation set
        n_samples = args.get('num_samples', 5)
        samples = self.pick_samples(Path(args['valid_dialogs']), n_samples)
        self.write_dialogs(samples, self.samples_file)
Example 13
 def save(self, path):
     log.info(f"Storing to {path}")
     # Dump only the state params and arrays instead of pickling the whole object:
     # a pickled object requires a matching torch version at load time, so saving
     # just the state makes the stored models more portable.
     state = dict(msg_reprs=self.msg_reprs,
                  resp_reprs=self.resp_reprs,
                  resps=self.resps,
                  msgs=self.msgs)
     torch.save(state, path)
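
A matching loader would mirror this; a sketch assuming the enclosing class's constructor accepts these four fields as keyword arguments:

    @classmethod
    def load(cls, path):
        state = torch.load(path, map_location='cpu')
        return cls(msg_reprs=state['msg_reprs'], resp_reprs=state['resp_reprs'],
                   resps=state['resps'], msgs=state['msgs'])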
Example 14
 def __init__(self, model_size, factor, warmup, optimizer, step=0):
     self.optimizer = optimizer
     self._step = step
     self.warmup = warmup
     self.factor = factor
     self.model_size = model_size
     self._rate = 0
     log.info(
         f"model_size={model_size}, factor={factor}, warmup={warmup}, step={step}"
     )
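
The constructor arguments match the Noam learning-rate schedule from the original Transformer recipe: warm up linearly for warmup steps, then decay with the inverse square root of the step. A sketch of the companion rate computation (this exact method body is an assumption, not necessarily this class's code):

    def rate(self, step=None):
        # lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)
        step = max(step if step is not None else self._step, 1)
        return self.factor * (self.model_size ** -0.5
                              * min(step ** -0.5, step * self.warmup ** -1.5))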
Example 15
def write_out(pairs, out):
    """
    Writes pairs as TSV records
    :param pairs: iterator to read pairs
    :param out: file stream to write output
    :return:
    """
    count = 0
    for rec in pairs:
        out.write("\t".join(rec) + "\n")
        count += 1
    log.info(f"Wrote {count} recs to {out.name}")
Example 16
def write_out(triples, out):
    """
    Writes triple records as TSV records
    :param triples: iterator to read triples
    :param out: file stream to write output
    :return:
    """
    count = 0
    for c, m, r in triples:
        rec = list(c) + list(m) + list(r)
        out.write("\t".join(rec) + "\n")
        count += 1
    log.info(f"Wrote {count} recs to {out.name}")
Example 17
 def __iter__(self):
     if self.shuffle:
         if not self._mem:
             log.info("Shuffling dialogs via an in-memory buffer; "
                      "this may cause OOM on large datasets")
             self._mem = list(self.read_all(self.path,
                                            add_eos=self.add_eos))
         random.shuffle(self._mem)
         dialogs = self._mem
     else:
         dialogs = self.read_all(self.path, add_eos=self.add_eos)
     yield from dialogs
Example 18
 def write_dialogs(dialogs: Iterator[Dialog], out: Path, dialog_sep='\n'):
     count = 0
     with IO.writer(out) as outh:
         for dialog in dialogs:
             count += 1
             for utter in dialog.chat:
                 if utter.uid:
                     outh.write(f'{utter.uid}\t')
                 if utter.weight is not None:  # don't drop zero weights
                     outh.write(f'{utter.weight:g}\t')
                 text = " ".join(map(str, utter.text))
                 outh.write(f'{utter.char}\t{text}\n')
             outh.write(dialog_sep)
     log.info(f"Wrote {count} recs to {out}")
Example 19
    def __init__(self, d_model, text_vocab, char_vocab, char_emb_size=None):
        super().__init__()

        self.text_emb = nn.Embedding(text_vocab, d_model)
        self.char_emb_size = char_emb_size
        if char_emb_size and char_emb_size > 0:  # None, zero, or negative disables this
            log.info(f"Character embeddings enabled: dim={char_emb_size}")
            self.char_emb = nn.Embedding(char_vocab, self.char_emb_size)
            self.merge = nn.Linear(self.char_emb_size + d_model, d_model)
        else:
            log.info("Character embeddings disabled")
            self.char_emb = None
            self.merge = None
        self.d_model = d_model
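
The forward pass implied by this constructor would look roughly like the sketch below; the shapes (token ids of [batch, seq_len], one character id per sequence) are assumptions:

    def forward(self, text_ids, char_ids):
        embs = self.text_emb(text_ids)                  # [B, T, d_model]
        if self.char_emb is not None:
            chars = self.char_emb(char_ids)             # [B, char_emb_size]
            # broadcast the character embedding over time and fuse it into each token
            chars = chars.unsqueeze(1).expand(-1, embs.size(1), -1)
            embs = self.merge(torch.cat([embs, chars], dim=-1))
        return embs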
Example 20
    def __iter__(self) -> Iterator[Dialog]:
        count = 0
        dialog = Dialog()
        for line in self.reader:
            line = line.strip()
            if not line:
                if len(dialog) > 0:
                    yield dialog
                    count += 1
                    dialog = Dialog()
                continue
            parts = line.split("\t")
            if len(parts) < self.num_cols:
                log.error(f"Skipping the line: {line}")
                continue
            self.num_cols = max(len(parts), self.num_cols)
            raw_char, raw_text = parts[-2:]
            uid = parts[0] if len(parts) > 2 else None
            char = raw_char = raw_char.strip()
            if self.char_field:
                char = self.char_field.encode_as_id(raw_char)
            if self.text_field:
                seq = self.text_field.encode_as_ids(raw_text, add_eos=True)
            else:
                seq = raw_text.strip().split()
                if self.add_eos and seq[-1] != EOS_TOK[0]:
                    seq.append(EOS_TOK[0])
            if len(seq) > self.max_seq_len:
                seq = seq[:self.max_seq_len - 1]
                seq.append(EOS_TOK_IDX if self.text_field else EOS_TOK[0])

            utter = Utterance(char,
                              seq,
                              raw_text=raw_text,
                              raw_char=raw_char,
                              uid=uid)
            dialog.append(utter)
        if len(dialog) > 0:
            count += 1
            yield dialog
        log.info(f"Read {count} dialogs")

        try:
            self.reader.close()
        except Exception:
            pass
Example 21
 def show_samples(self, beam_size=5, num_hyp=5, max_len=30, skip_top=0):
     """
     Logs the model's output (at the current stage of training) on a set of samples
     :param beam_size: beam size
     :param num_hyp: number of hypotheses to output
     :param max_len: maximum length to decode
     :param skip_top: number of top beams to skip (to improve diversity)
     :return:
     """
     if not self.samples:
         log.info("No samples are chosen by the experiment")
         return
     self.decoder.decode_dialogs(self.samples,
                                 out=None,
                                 beam_size=beam_size,
                                 num_hyp=num_hyp,
                                 max_len=max_len,
                                 skip_top=skip_top)
Example 22
 def pre_process_finetune(self, args=None):
     """
     Pre process records for fine tuning
     :param args: prep arguments; defaults to self.config['prep'] when not given
     :return:
     """
     log.info("Going to prep fine tune files")
     args = args if args else self.config['prep']
     assert 'finetune_dialogs' in args
     dialogs = RawDialogReader(args['finetune_dialogs'],
                               text_field=self.text_field,
                               char_field=self.char_field,
                               max_seq_len=args['max_seq_len'])
     dialogs = list(dialogs)
     weights = sampling_weights(cluster(dialogs).values())
     for dlg in dialogs:
         for utter in dlg.chat:
             utter.weight = weights[utter.uid]
     self.write_dialogs(dialogs, self.finetune_file)
Example 23
    def __iter__(self):
        count = 0
        utters: OrderedSet = OrderedSet()
        chats: List[ChatRec] = list()

        def utters_space():
            return self.max_utters - len(utters)

        def chat_space():
            return self.max_chats - len(chats)

        for dialog in self.reader:
            for chat in dialog.as_mini_chats(min_ctx=self.min_ctx,
                                             max_ctx=self.max_ctx,
                                             model_chars=self.model_chars,
                                             min_resp_len=self.min_resp_len,
                                             no_repeat=self.no_repeat,
                                             down_sample=self.down_sample):
                # this might exceed max_utters, but that's okay
                utters.maybe_update(chat.context)
                utters.maybe_add(chat.response)
                chats.append(chat)
                if utters_space() <= 0 or chat_space() <= 0:
                    batch = DialogMiniBatchRaw.new(utters.to_list(),
                                                   chats=chats,
                                                   sort_desc=self.sort_desc,
                                                   pad=self.pad)
                    yield batch
                    count += 1
                    utters.clear()
                    chats.clear()
        if chats:  # left over in the buffer
            yield DialogMiniBatchRaw.new(utters.to_list(),
                                         chats=chats,
                                         sort_desc=self.sort_desc,
                                         pad=self.pad)
            count += 1
        if count != self.last_count:
            log.info(f"Produced {count} dialog batches")
            self.last_count = count
Example 24
def __test_seq2seq_model__():
    work_dir = '/Users/tg/work/phd/cs644/project/virtchar/tmp.work'
    exp = Experiment(work_dir, read_only=True)
    text_vocab = len(exp.text_field)
    char_vocab = len(exp.char_field)

    emb_size = 100
    char_emb_size = 50

    step_size = 50
    model_dim = 100

    steps = 2000
    check_pt = 10

    log.info(f"====== VOCAB={text_vocab}, Characters:{char_vocab}======")
    model, args = HRED.make_model(text_vocab=text_vocab, char_vocab=char_vocab,
                                  text_emb_size=emb_size, char_emb_size=char_emb_size,
                                  hid_size=model_dim, n_layers=1)
    trainer = SteppedHREDTrainer(exp=exp, model=model, lr=0.01, warmup_steps=500)
    trainer.train(steps=steps, step_size=step_size, check_point=check_pt)
Example 25
def log_tensor_sizes(writer=log.info, min_size=1024):
    """
    Forces garbage collector and logs all the current tensors
    :return:
    """
    log.info("Collecting tensor allocations")
    gc.collect()

    def is_tensor(obj):
        if torch.is_tensor(obj):
            return True
        try:  # some native objects raise exceptions
            return hasattr(obj, 'data') and torch.is_tensor(obj.data)
        except:
            return False

    tensors = filter(is_tensor, gc.get_objects())
    stats = ((reduce(op.mul, obj.size()) if len(obj.size()) > 0 else 0,
              obj.type(), tuple(obj.size()), hex(id(obj))) for obj in tensors)
    stats = ((n * tensor_size[typ], n, typ, *blah) for n, typ, *blah in stats)
    stats = (x for x in stats if x[0] > min_size)
    sorted_stats = sorted(stats, key=lambda x: x[0])

    writer("####\tApprox Bytes\tItems       \tShape   \tObject ID")
    lines = (f'{i:4}\t{size:12,}\t{n:12,}\t{typ}\t{shape}\t{_id}'
             for i, (size, n, typ, shape, _id) in enumerate(sorted_stats))
    log.info("==== Tensors and memories === ")
    for line in lines:
        writer(line)

    total = sum(rec[0] for rec in sorted_stats)
    log.info(
        f'Total Bytes by tensors  bigger than {min_size} is (approx):{total:,}'
    )
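
The tensor_size lookup used above is assumed to map a tensor type name to its bytes per element; a plausible sketch:

    from collections import defaultdict

    # assume 4 bytes per element unless the type says otherwise
    tensor_size = defaultdict(lambda: 4, {
        'torch.FloatTensor': 4, 'torch.cuda.FloatTensor': 4,
        'torch.DoubleTensor': 8, 'torch.cuda.DoubleTensor': 8,
        'torch.LongTensor': 8, 'torch.cuda.LongTensor': 8,
        'torch.ByteTensor': 1, 'torch.cuda.ByteTensor': 1,
    })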
Example 26
    def __init__(self,
                 work_dir: Union[str, Path],
                 read_only=False,
                 config: Union[str, Dict[str, Any], None] = None):
        if isinstance(work_dir, str):
            work_dir = Path(work_dir)

        log.info(f"Initializing an experiment. Directory = {work_dir}")
        self.read_only = read_only
        self.work_dir = work_dir
        self.data_dir = work_dir / 'data'
        self.model_dir = work_dir / 'models'
        self._config_file = work_dir / 'conf.yml'
        self._text_field_file = self.data_dir / 'text.model'
        self._char_field_file = self.data_dir / 'vocab.char.txt'
        self._prepared_flag = self.work_dir / '_PREPARED'
        self._trained_flag = self.work_dir / '_TRAINED'

        self.train_file = self.data_dir / 'train.tsv.gz'
        self.finetune_file = self.data_dir / 'finetune.tsv.gz'
        self.valid_file = self.data_dir / 'valid.tsv.gz'
        # a set of samples to watch the progress qualitatively
        self.samples_file = self.data_dir / 'samples.tsv.gz'

        if not read_only:
            for _dir in [self.model_dir, self.data_dir]:
                if not _dir.exists():
                    _dir.mkdir(parents=True)
        if isinstance(config, str):
            config = load_conf(config)
        self.config = config if config else load_conf(self._config_file)

        self.text_field = Field(str(self._text_field_file)) \
            if self._text_field_file.exists() else None
        self.char_field = LookupField(str(self._char_field_file)) \
            if self._char_field_file.exists() else None

        # these are the characters to which we optimize the loss
        self._model_chars = None
Example 27
    def train(self, steps: int, check_point: int, fine_tune=False,
              check_pt_callback: Optional[Callable] = None, **args):
        log.info(f'Going to train for {steps} steps; '
                 f'check point size:{check_point}; fine tune={fine_tune}')
        keep_models = args.get('keep_models', 4)  # keep the best N models and delete the rest

        if steps <= self.start_step:
            raise Exception(f'The model was already trained to {self.start_step} steps. '
                            f'Please increase the steps or clear the existing models')
        train_data = self.exp.get_train_data(loop_steps=steps - self.start_step)
        val_data = self.exp.get_val_data()

        train_state = TrainerState(self.model, check_point=check_point)
        train_state.train_mode(True)
        with tqdm(train_data, initial=self.start_step, total=steps, unit='batch') as data_bar:
            for batch in data_bar:
                # Step clear gradients
                self.model.zero_grad()

                # Step Run forward pass.
                outp_log_probs = self.model(batch)
                loss = self.loss_func(outp_log_probs, batch, True)
                self.tbd.add_scalars('training', {'step_loss': loss,
                                                  'learn_rate': self.opt.curr_lr},
                                     self.opt.curr_step)
                bar_msg, is_check_pt = train_state.step(batch.tot_resp_toks.item(), loss)
                bar_msg += f', LR={self.opt.curr_lr:g}'
                data_bar.set_postfix_str(bar_msg, refresh=False)

                del batch  # TODO: force free memory
                if is_check_pt:
                    train_loss = train_state.reset()
                    train_state.train_mode(False)
                    self.make_check_point(val_data, train_loss, keep_models=keep_models)
                    if check_pt_callback:
                        check_pt_callback(model=self.model,
                                          step=self.opt.curr_step,
                                          train_loss=train_loss)
                    train_state.train_mode(True)
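
The check_pt_callback hook receives the keyword arguments shown in the call above; a minimal hypothetical implementation:

    def my_callback(model=None, step=0, train_loss=float('inf')):
        # e.g. decode a few held-out samples or export the model at each checkpoint
        log.info(f"checkpoint at step {step}: train_loss={train_loss:g}")

    # trainer.train(steps=100_000, check_point=1000, check_pt_callback=my_callback)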
Example 28
    def get_train_data(self, shuffle=False, fine_tune=False, loop_steps=0, sort_dec=True) \
            -> Iterator[DialogMiniBatch]:
        assert not shuffle, 'Not supported at the moment'
        inp_file = self.train_file
        if fine_tune:
            if not self.finetune_file.exists():
                # user may have added fine tune file later
                self.pre_process_finetune()
            log.info("Using Fine tuning corpus instead of training corpus")
            inp_file = self.finetune_file

        reader = DialogReader(inp_file)
        train_data = DialogBatchReader(reader,
                                       min_ctx=self.min_ctx,
                                       max_ctx=self.max_ctx,
                                       max_chats=self.max_chats,
                                       max_utters=self.max_utters,
                                       model_chars=None,
                                       min_resp_len=self.min_resp_len,
                                       no_repeat=self.no_repeat,
                                       sort_desc=sort_dec)
        if loop_steps > 0:
            return LoopingIterable(train_data, total=loop_steps)
        return train_data
Example 29
    def make_check_point(self, val_data: Iterator[DialogMiniBatch],
                         train_loss: float, keep_models: int):
        """
        Checkpoint the model
        :param val_data: validation data to obtain validation score
        :param train_loss: training loss value
        :param keep_models: how many checkpoints to keep on file system
        :return:
        """
        step_num = self.opt.curr_step
        val_loss = self.run_valid_epoch(val_data)
        log.info(
            f"Checkpoint at step {step_num}. Training Loss {train_loss:g},"
            f" Validation Loss:{val_loss:g}")
        self.show_samples()

        self.tbd.add_scalars('losses', {
            'train_loss': train_loss,
            'valid_loss': val_loss
        }, step_num)
        # Unwrap model state from DataParallel and persist
        model = (self.model.module
                 if isinstance(self.model, nn.DataParallel) else self.model)
        state = {
            'model_state': model.state_dict(),
            'optim_state': self.opt.optimizer.state_dict(),
            'step': step_num,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'time': time.time(),
            'rtg_version': virtchar.__version__
        }
        self.exp.store_model(step_num,
                             state,
                             train_score=train_loss,
                             val_score=val_loss,
                             keep=keep_models)
Example 30
    @classmethod
    def new(cls, exp: DialogExperiment, model=None, gen_args=None,
            model_paths: Optional[List[str]] = None,
            ensemble: int = 1):
        """
        create a new decoder
        :param exp: experiment
        :param model: optional pre-initialized model
        :param gen_args: any optional args needed for generator
        :param model_paths: optional model paths
        :param ensemble: number of models to use for ensembling (if model is not specified)
        :return:
        """
        if model is None:
            factory = factories[exp.model_type]
            model = factory(**exp.model_args)[0]
            state = cls.maybe_ensemble_state(exp, model_paths=model_paths, ensemble=ensemble)
            model.load_state_dict(state)
            log.info("Successfully restored the model state.")
        elif isinstance(model, nn.DataParallel):
            model = model.module

        model = model.eval().to(device=device)
        generator = generators[exp.model_type]
        return cls(model, generator, exp, gen_args)
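
A hypothetical construction, assuming the enclosing class is named Decoder and the experiment directory is illustrative:

    exp = DialogExperiment('runs/exp1', read_only=True)
    decoder = Decoder.new(exp, ensemble=3)  # average the states of the best 3 checkpoints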