Example #1
def fork_experiment(from_exp: Path, to_exp: Path, conf: bool, vocab: bool, data: bool, code: bool):
    assert from_exp.exists()
    log.info(f'Fork: {str(from_exp)} → {str(to_exp)}')
    if not to_exp.exists():
        log.info(f"Create dir {str(to_exp)}")
        to_exp.mkdir(parents=True)
    if conf:
        conf_file = to_exp / 'conf.yml'
        IO.maybe_backup(conf_file)
        IO.copy_file(from_exp / 'conf.yml', conf_file)
    if data:
        to_data_dir = (to_exp / 'data')
        from_data_dir = from_exp / 'data'
        if to_data_dir.is_symlink():
            log.info(f"removing the existing data link: {to_data_dir.resolve()}")
            to_data_dir.unlink()
        assert not to_data_dir.exists()
        assert from_data_dir.exists()
        log.info(f"link {to_data_dir} → {from_data_dir}")
        to_data_dir.symlink_to(from_data_dir.resolve())
        (to_exp / '_PREPARED').touch(exist_ok=True)
    if not data and vocab: # just the vocab
        Experiment(from_exp, read_only=True).copy_vocabs(
            Experiment(to_exp, config={'Not': 'Empty'}, read_only=True))

    if code:
        for f in ['rtg.zip', 'githead']:
            src = from_exp / f
            if not src.exists():
                log.warning(f"File Not Found: {src}")
                continue
            IO.copy_file(src, to_exp / f)
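A minimal usage sketch for the helper above (the paths and flag combination are hypothetical, only to show the call shape):

from pathlib import Path

# fork only the config and vocab from an existing experiment dir into a new one (illustrative paths)
fork_experiment(from_exp=Path('runs/base-exp'), to_exp=Path('runs/forked-exp'),
                conf=True, vocab=True, data=False, code=False)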
Example #2
    def write(cls, path, records: Iterator[ParallelSeqRecord]):
        if path.exists():
            log.warning(f"Overwriting {path} with new records")
            os.remove(str(path))
        maybe_tmp = IO.maybe_tmpfs(path)
        log.info(f'Creating {maybe_tmp}')
        conn = sqlite3.connect(str(maybe_tmp))
        cur = conn.cursor()
        cur.execute(cls.TABLE_STATEMENT)
        cur.execute(cls.INDEX_X_LEN)
        cur.execute(cls.INDEX_Y_LEN)
        cur.execute(f"PRAGMA user_version = {cls.CUR_VERSION};")

        count = 0
        for x_seq, y_seq in records:
            # use numpy; it's a lot more efficient
            if not isinstance(x_seq, np.ndarray):
                x_seq = np.array(x_seq, dtype=np.int32)
            if y_seq is not None and not isinstance(y_seq, np.ndarray):
                y_seq = np.array(y_seq, dtype=np.int32)
            values = (x_seq.tobytes(),
                      None if y_seq is None else y_seq.tobytes(), len(x_seq),
                      len(y_seq) if y_seq is not None else -1)
            cur.execute(cls.INSERT_STMT, values)
            count += 1
        cur.close()
        conn.commit()
        if maybe_tmp != path:
            # bring the file back to original location where it should be
            IO.copy_file(maybe_tmp, path)
        log.info(f"stored {count} rows in {path}")
Example #3
 def forward(self, x, score=None, gen_probs=True, log_probs=True):
     """
     :param x: features or hidden states
     :param score: which scores do you want in return? Your options are
         'logits' -- scores without any normalization
         'softmax' -- raw probs for multi-class
         'log_softmax' -- log probs for multi-class
         'sigmoid' -- for multi-label tasks
     :param gen_probs: (deprecated, use 'score=logits') False to get logits; default is True
     :param log_probs: (deprecated, use score='log_softmax' or 'softmax').
         False to get raw probs from softmax, True to get probs from log_softmax.
     :return: scores based on choice of score=xxx
     """
     # made this mess to preserve backward compatibility
     if not score:
         score = 'logits'
         if gen_probs:
             score = 'log_softmax' if log_probs else 'softmax'
         warn_msg = f'API deprecated. use "score={score}" attribute.'
         if warn_msg not in self.warn_msgs:  # warn only Once
             self.warn_msgs.add(warn_msg)
             log.warning(warn_msg)
             traceback.print_stack(limit=6)
     assert score in self.scores, f'{self.scores.keys()} supported but given "{score}"'
     if score == 'embedding' or score == 'identity':
         return x
     x = self.proj(x)
     return self.scores[score](x, dim=-1)
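For illustration, calls that use the non-deprecated score argument (generator and hidden are placeholders; hidden would be a [batch, seq_len, hid_size] tensor of decoder states):

logits = generator(hidden, score='logits')          # unnormalized scores
probs = generator(hidden, score='softmax')          # probabilities
log_probs = generator(hidden, score='log_softmax')  # log probabilities (the old default behaviour)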
Example #4
def get_stats(data, n, limit=-1):
    n_seqs = 0
    lens = []
    freqs = np.zeros(n, dtype=np.int32)
    for seq in tqdm(data):
        n_seqs += 1
        for i in seq:
            freqs[i] += 1
        lens.append(len(seq))
        if limit > 0 and n_seqs > limit:
            log.warning(
                f"Aborting at {n_seqs} records though here are more in the dataset"
            )
            break
    lens = np.array(lens)
    total_toks = np.sum(lens)
    assert total_toks == np.sum(freqs)  # sanity
    # Skip zero frequencies; they could be reserved tokens, or belong to the other side if the vocab is shared
    n_zero_types = sum(1 for f in freqs if f == 0)
    n_effective = n - n_zero_types
    probs = freqs / total_toks
    imbalance = 0.5 * np.sum(np.abs(1 / n_effective - probs))
    res = dict(n_seqs=n_seqs,
               total_toks=total_toks,
               mean_len=np.mean(lens),
               median_len=np.median(lens),
               max_len=np.max(lens),
               EMD=imbalance,
               n=n,
               zero_types=n_zero_types,
               effective_n=n_effective)
    return freqs, res
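A toy invocation showing what the function returns (the id sequences below are made up; n is the vocabulary size):

toy_data = [[0, 1, 1, 2], [1, 2, 2, 2, 3]]  # two id sequences over a vocab of n=5 types
freqs, res = get_stats(toy_data, n=5)
# res['EMD'] = 0.5 * sum_i |1/n_effective - p_i|, a measure of token-type imbalance
# relative to a uniform distribution over the types that actually occur
print(res['total_toks'], res['zero_types'], res['EMD'])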
Example #5
def main():
    args = parse_args()
    seed = args.pop("seed")
    if seed:
        log.info(f"Seed for random number generator: {seed}")
        import random
        import torch
        random.seed(seed)
        torch.manual_seed(seed)

    work_dir = Path(args.pop('work_dir'))
    is_big = load_conf(work_dir / 'conf.yml').get('spark', {})

    if is_big:
        log.info("Big experiment mode enabled; checking pyspark backend")
        try:
            import pyspark
        except ImportError:
            log.warning("unable to import pyspark. Please do 'pip install pyspark' and run again")
            raise
        from rtg.big.exp import BigTranslationExperiment
        exp = BigTranslationExperiment(work_dir=work_dir)
    else:
        exp = Experiment(work_dir=work_dir)
    assert exp.has_prepared(), f'Experiment dir {exp.work_dir} is not ready to train. ' \
                               f'Please run "prep" sub task'
    exp.train(args)
Example #6
    def train(self,
              steps: int,
              check_point: int,
              batch_size: int,
              check_pt_callback: Optional[Callable] = None,
              side='tgt',
              ctx_size=2,
              **args):
        log.info(f"using side={side}, ctx_size={ctx_size}")
        reader = DataReader(self.exp, side=side)
        rem_steps = steps - self.start_step
        if rem_steps <= 0:
            log.warning(f"Already trained upto {self.start_step-1}. Skipped")
            return
        train_data = reader.get_training_data(batch_size=batch_size,
                                              n_batches=rem_steps,
                                              ctx_size=ctx_size)
        val_data = reader.get_val_data(batch_size=batch_size,
                                       ctx_size=ctx_size)
        train_loss, n = 0.0, 0

        def _make_checkpt(step, train_loss):
            with torch.no_grad():
                val_loss = self.run_valid_epoch(val_data)
            log.info(f"Checkpoint at {step}")
            self.save_embeddings(step, train_loss, val_loss, txt=True)
            self.tbd.add_scalars('losses', {
                'training': train_loss,
                'valid_loss': val_loss
            },
                                 global_step=step)

        with tqdm(train_data,
                  initial=self.start_step,
                  total=rem_steps + 1,
                  unit='batch',
                  dynamic_ncols=True) as data_bar:
            for i, (xs, ys) in enumerate(data_bar, start=self.start_step):
                self.model.zero_grad()
                xs, ys = xs.to(device), ys.to(device)
                log_probs = self.model(xs)
                loss = self.loss_func(log_probs, ys)
                self.tbd.add_scalars('training', {
                    'step_loss': loss.item(),
                    'learn_rate': self.opt.curr_lr
                }, self.opt.curr_step)
                progress_msg = f', loss={loss:g} LR={self.opt.curr_lr:g}'
                data_bar.set_postfix_str(progress_msg, refresh=False)
                train_loss += loss.item()
                n += len(ys)

                loss.backward()
                self.opt.step()
                self.opt.zero_grad()

                if i > 0 and i % check_point == 0:
                    _make_checkpt(i, train_loss / n)
                    train_loss, n = 0.0, 0
        if n > 0:
            _make_checkpt(steps, train_loss / n)
Example #7
    def make_model(cls,
                   src_vocab,
                   tgt_vocab,
                   enc_layers=6,
                   dec_layers=6,
                   hid_size=512,
                   ff_size=2048,
                   n_heads=8,
                   attn_bias=True,
                   attn_dropout=0.1,
                   dropout=0.2,
                   activation='relu',
                   enc_depth_probs: List[float] = (1.0, 0.9, 0.8, 0.7, 0.6,
                                                   0.5),
                   dec_depth_probs: List[float] = (1.0, 0.9, 0.8, 0.7, 0.6,
                                                   0.5),
                   tied_emb='three-way',
                   exp: Experiment = None):
        """Helper: Construct a model from hyper parameters."""
        assert len(enc_depth_probs) == enc_layers
        assert len(dec_depth_probs) == dec_layers
        # get all args for reconstruction at a later phase
        args = get_my_args(exclusions=['cls', 'exp'])

        assert activation in {'relu', 'elu', 'gelu'}
        log.info(f"Make model, Args={args}")
        c = copy.deepcopy
        attn = tfm.MultiHeadedAttention(n_heads,
                                        hid_size,
                                        dropout=attn_dropout,
                                        bias=attn_bias)
        ff = tfm.PositionwiseFeedForward(hid_size,
                                         ff_size,
                                         dropout,
                                         activation=activation)

        if enc_layers == 0:
            log.warning("Zero encoder layers!")
        encoder = SkipEncoder(
            tfm.EncoderLayer(hid_size, c(attn), c(ff), dropout), enc_layers,
            enc_depth_probs)

        assert dec_layers > 0
        decoder = SkipDecoder(
            tfm.DecoderLayer(hid_size, c(attn), c(attn), c(ff), dropout),
            dec_layers, dec_depth_probs)

        src_emb = nn.Sequential(tfm.Embeddings(hid_size, src_vocab),
                                tfm.PositionalEncoding(hid_size, dropout))
        tgt_emb = nn.Sequential(tfm.Embeddings(hid_size, tgt_vocab),
                                tfm.PositionalEncoding(hid_size, dropout))
        generator = tfm.Generator(hid_size, tgt_vocab)

        model = cls(encoder, decoder, src_emb, tgt_emb, generator)

        if tied_emb:
            model.tie_embeddings(tied_emb)

        model.init_params()
        return model, args
Example #8
 def enable_fp16(self):
     if not self.fp16:  # conditional import
         self.fp16 = True
         self._scaler = GradScaler(enabled=self.fp16)
         log.info("Enabling FP16  /Automatic Mixed Precision training")
     else:
         log.warning(" fp16 is already enabled")
Example #9
    def from_lines(cls,
                   lines: Iterator[str],
                   batch_size: int,
                   vocab: Field,
                   sort=True,
                   max_src_len=0,
                   max_len_buffer=0):
        """
        Note: this changes the order based on sequence length if sort=True
        :param lines: stream of input lines
        :param batch_size: number of tokens in batch
        :param vocab: Field to use for mapping word pieces to ids
        :param sort: sort based on descending order of length
        :param max_src_len: truncate sequences to this length; 0 disables truncation
        :return: stream of DecoderBatches
        """
        log.info("Tokenizing sequences")
        buffer = []
        for i, line in enumerate(lines):
            line = line.strip()
            if not line:
                log.warning(
                    f"line {i + 1} was empty. inserting a dot (.). "
                    f"Empty lines are problematic when you want line-by-line alignment..."
                )
                line = "."
            cols = line.split('\t')
            id, ref = None, None
            if len(cols) == 1:  # SRC
                src = cols[0]
            elif len(cols) == 2:  # ID \t SRC
                id, src = cols
            else:  # ID \t SRC \t REF
                id, src, ref = cols[:3]
            seq = vocab.encode_as_ids(src, add_eos=True, add_bos=False)
            if max_src_len > 0 and len(seq) > max_src_len:
                log.warning(
                    f"Line {i} full length={len(seq)} ; truncated to {max_src_len}"
                )
                seq = seq[:max_src_len]
            buffer.append((i, src, ref, seq, id))  # idx, src, ref, seq, id

        if sort:
            log.info(f"Sorting based on the length. total = {len(buffer)}")
            buffer = sorted(buffer, reverse=True,
                            key=lambda x: len(x[3]))  # sort by length of seq

        batch = cls()
        batch.max_len_buffer = max_len_buffer
        for idx, src, ref, seq, _id in buffer:
            batch.add(idx=idx, src=src, ref=ref, seq=seq, id=_id)
            if batch.padded_tok_count >= batch_size:
                yield batch
                batch = cls()
                batch.max_len_buffer = max_len_buffer

        if batch.line_count > 0:
            yield batch
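The three input layouts handled by the column logic above, as a hedged sketch (DecoderBatch, field and the sentences are illustrative stand-ins; the source only shows the classmethod as cls):

lines = iter([
    "hello world",                          # SRC only
    "sent-2\thow are you",                  # ID \t SRC
    "sent-3\tgood morning\tguten Morgen",   # ID \t SRC \t REF
])
for batch in DecoderBatch.from_lines(lines, batch_size=512, vocab=field,
                                     sort=True, max_src_len=128):
    ...  # each yielded batch holds roughly batch_size padded tokens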
Example #10
    def make_model(
            cls,
            src_vocab,
            tgt_vocab,
            enc_layers=6,
            dec_layers=6,
            hid_size=512,
            n_heads=8,
            attn_dropout=0.1,
            dropout=0.2,
            activation='relu',
            eff_dims: List[int] = (1024, 1024, 2048, 2048, 1024,
                                   1024),  # Using tuple for immutability
            dff_dims: List[int] = (1024, 1024, 2048, 2048, 1024, 1024),
            tied_emb='three-way',
            exp: Experiment = None):
        """Helper: Construct a model from hyper parameters."""

        assert enc_layers == len(eff_dims)
        assert dec_layers == len(dff_dims)

        # get all args for reconstruction at a later phase
        args = get_my_args(exclusions=['cls', 'exp'])
        assert activation in {'relu', 'elu', 'gelu'}
        log.info(f"Make model, Args={args}")

        if enc_layers == 0:
            log.warning("Zero encoder layers!")
        encoder = WidthVaryingEncoder(d_model=hid_size,
                                      ff_dims=eff_dims,
                                      N=enc_layers,
                                      n_heads=n_heads,
                                      attn_dropout=attn_dropout,
                                      dropout=dropout,
                                      activation=activation)

        assert dec_layers > 0
        decoder = WidthVaryingDecoder(d_model=hid_size,
                                      ff_dims=dff_dims,
                                      N=dec_layers,
                                      n_heads=n_heads,
                                      attn_dropout=attn_dropout,
                                      dropout=dropout,
                                      activation=activation)

        src_emb = nn.Sequential(Embeddings(hid_size, src_vocab),
                                PositionalEncoding(hid_size, dropout))
        tgt_emb = nn.Sequential(Embeddings(hid_size, tgt_vocab),
                                PositionalEncoding(hid_size, dropout))
        generator = Generator(hid_size, tgt_vocab)

        model = cls(encoder, decoder, src_emb, tgt_emb, generator)

        if tied_emb:
            model.tie_embeddings(tied_emb)

        model.init_params()
        return model, args
Example #11
 def backward(self, loss):
     if torch.isnan(loss):
         log.warning('loss is nan; backward() skipped')
         return
     if self.fp16:
         loss = self._scaler.scale(loss)
         # to apply norm: TODO: unscale gradients ; refer to docs
         # torch.nn.utils.clip_grad_norm_(self._amp.master_params(opt.optimizer), self.max_norm)
     loss.backward()
Example #12
    def make_check_point(self,
                         train_loss: float,
                         val_loss: float,
                         keep_models: int,
                         log_embedding=False):
        """
        Check point the model
        :param train_loss: training loss value
        :param val_loss: loss on validation set
        :param keep_models: how many checkpoints to keep on file system
        :param log_embedding: if True, also log the target embedding matrix to tensorboard
        :return:
        """

        step_num = self.opt.curr_step
        if step_num == self.last_step:
            log.warning("Ignoring checkpt request")
            return  # calling multiple times doesn't save
        log.info(
            f"Checkpoint at optimizer step {step_num}. Training Loss {train_loss:g},"
            f" Validation Loss:{val_loss:g}")
        self.show_samples()

        self.tbd.add_scalars(f'losses', {
            'train_loss': train_loss,
            'valid_loss': val_loss
        }, step_num)
        if log_embedding:
            # TODO: add metadata (text) of each subword
            # TODO: Update tag to include tie configuration
            self.tbd.add_embedding(self.model.generator.proj.weight,
                                   global_step=step_num,
                                   tag=f'Target embeddings')

        # Unwrap model state from DataParallel and persist
        model = (self.model.module
                 if hasattr(self.model, 'module') else self.model)
        state = {
            'model_state': model.state_dict(),
            'optim_state': self.opt.optimizer.state_dict(),
            'step': step_num,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'time': time.time(),
            'rtg_version': rtg.__version__,
            'model_type': self.exp.model_type,
            'model_args': self.exp.model_args,
        }
        if dtorch.fp16:
            state['amp_state'] = dtorch._scaler.state_dict()

        self.exp.store_model(step_num,
                             state,
                             train_score=train_loss,
                             val_score=val_loss,
                             keep=keep_models)
        self.last_step = step_num
Example #13
    def _pre_process_parallel(self,
                              src_key: str,
                              tgt_key: str,
                              out_file: Path,
                              args: Optional[Dict[str, Any]] = None,
                              line_check=False,
                              **kwargs):
        """
        Pre process records of a parallel corpus
        :param args: all arguments for 'prep' task
        :param src_key: key that contains source sequences
        :param tgt_key: key that contains target sequences
        :param out_file: path to store processed TSV data (compresses if name ends with .gz)
        :return:
        """
        if kwargs:
            log.warning(f"The following args are ignored:{kwargs}")
        if not out_file.name.endswith(".nldb"):
            if 'train' in out_file.name:
                log.warning(f"set .nldb extension to enable spark")
            return super()._pre_process_parallel(src_key=src_key,
                                                 tgt_key=tgt_key,
                                                 out_file=out_file,
                                                 args=args,
                                                 line_check=line_check)

        args = args if args else self.config['prep']
        log.info(f"Going to prep files {src_key} and {tgt_key}")
        assert src_key in args, f'{src_key} not found in experiment config or args'
        assert tgt_key in args, f'{tgt_key} not found in experiment config or args'

        with log_resources(f"create {out_file.name}"):
            with spark_session(config=self.spark_conf) as spark:
                rdd, total = read_raw_parallel_recs(
                    spark,
                    src_path=args[src_key],
                    tgt_path=args[tgt_key],
                    truncate=args['truncate'],
                    src_len=args['src_len'],
                    tgt_len=args['tgt_len'],
                    src_tokenizer=self.src_vocab.encode_as_ids,
                    tgt_tokenizer=self.tgt_vocab.encode_as_ids)

                id_rdd = rdd.map(lambda r: (r[0],
                                            (r[1], r[2])))  # (id, (x, y))
                max_part_size = args.get('max_part_size', 1_000_000)
                n_parts = math.ceil(total / max_part_size)
                log.info(
                    f"Writing to {out_file}; {n_parts} parts,"
                    f" not exceeding {max_part_size:,} records in each part")

                rdd_as_db(id_rdd,
                          db_path=out_file,
                          field_names=['x', 'y'],
                          overwrite=True,
                          repartition=n_parts)
Example #14
 def maybe_init_from_parent(self, exp: 'TranslationExperiment'):
     if exp.parent_model_state.exists():
         log.info("YES Initialising from a parent model")
         device = next(self.parameters()).device  # device of self model
         state = torch.load(exp.parent_model_state, map_location=device)
         error = self.load_state_dict(state, strict=False)
         log.info("YES Initialized from the parent model")
         if error.missing_keys or error.unexpected_keys:
             log.warning(f"Error keys: {error}")
     else:
         log.info("NOT initialising from parent model")
Example #15
    def make_model(cls,
                   src_vocab,
                   tgt_vocab,
                   enc_layers=6,
                   dec_layers=6,
                   hid_size=512,
                   ff_size=2048,
                   n_heads=8,
                   dropout=0.1,
                   tied_emb='three-way',
                   activation='relu',
                   exp: Experiment = None):
        "Helper: Construct a model from hyper parameters."

        # get all args for reconstruction at a later phase
        _, _, _, args = inspect.getargvalues(inspect.currentframe())
        for exclusion in ['cls', 'exp']:
            del args[exclusion]  # exclude some args
        # In case you are wondering why I didn't use **kwargs here:
        #   these args are read from a conf file where the user can introduce errors, so parameter
        #   validation and default-value assignment are implicitly done by the function call for us :)
        assert activation in {'relu', 'elu', 'gelu'}
        log.info(f"Make model, Args={args}")
        c = copy.deepcopy
        attn = MultiHeadedAttention(n_heads, hid_size, dropout=dropout)
        ff = PositionwiseFeedForward(hid_size,
                                     ff_size,
                                     dropout,
                                     activation=activation)

        if enc_layers == 0:
            log.warning("Zero encoder layers!")
        encoder = Encoder(EncoderLayer(hid_size, c(attn), c(ff), dropout),
                          enc_layers)

        assert dec_layers > 0
        decoder = Decoder(
            DecoderLayer(hid_size, c(attn), c(attn), c(ff), dropout),
            dec_layers)

        src_emb = nn.Sequential(Embeddings(hid_size, src_vocab),
                                PositionalEncoding(hid_size, dropout))
        tgt_emb = nn.Sequential(Embeddings(hid_size, tgt_vocab),
                                PositionalEncoding(hid_size, dropout))
        generator = Generator(hid_size, tgt_vocab)

        model = cls(encoder, decoder, src_emb, tgt_emb, generator)

        if tied_emb:
            model.tie_embeddings(tied_emb)

        model.init_params()
        return model, args
Example #16
File: utils.py Project: isi-nlp/rtg
 def safe_delete(cls, path: Path):
     try:
         if path.exists():
             if path.is_file():
                 log.info(f"Delete file {path}")
                 path.unlink()
             elif path.is_dir():
                 log.info(f"Delete dir {path}")
                 path.rmdir()
             else:
                 log.warning(f"Coould not delete {path}")
     except:
         log.exception(f"Error while clearning up {path}")
Example #17
 def __iter__(self) -> Iterator[IdExample]:
     for d in self.read_all():
         id, x, y = d['id'], d['x'], d.get('y')
         if x is None or y is None or len(x) == 0 or len(y) == 0:
             log.warning(
                 f"Ignoring an empty record   x:{len(x)}    y:{len(y)}")
             continue
         if len(x) > self.max_src_len or len(y) > self.max_tgt_len:
             if self.truncate:
                 x, y = x[:self.max_src_len], y[:self.max_tgt_len]
             else:  # skip this record
                 continue
         yield IdExample(x=x, y=y, id=id)
Example #18
 def __iter__(self) -> Iterator[Example]:
     for d in self.read_all():
         x, y = d['x'], d.get('y')
         if not x or not y:
             log.warning(
                 f"Ignoring an empty record   x:{len(x)}    y:{len(y)}")
             continue
         if len(x) > self.max_src_len or len(y) > self.max_tgt_len:
             if self.truncate:
                 x, y = x[:self.max_src_len], y[:self.max_tgt_len]
             else:  # skip this record
                 continue
         yield Example(x, y)
Example #19
 def __init__(self,
              model: TransformerNMT,
              field,
              x_seqs,
              x_lens=None,
              multi_label=False):
     super().__init__(model, field)
     self.x_mask = (x_seqs != field.pad_idx).unsqueeze(1)
     self.memory = self.model.encode(x_seqs, self.x_mask)
     self.multi_label = multi_label
     if multi_label and not type(self).multi_label_warned:
         log.warning(">>> Multi-label decoding mode enabled")
         type(self).multi_label_warned = True
Example #20
def run_command(cmd_line: str, fail_on_error=True):
    log.info(f'RUN:: {cmd_line}')
    proc = subprocess.run(cmd_line, shell=True, capture_output=True, text=True)
    if proc.returncode == 0:
        log.info(f"STDOUT={proc.stdout}")
        if proc.stderr:
            log.warning(f"STDERR={proc.stderr}")
    else:
        msg = f'CMD={cmd_line}\nCODE={proc.returncode}\nSTDOUT={proc.stdout}\nSTDERR={proc.stderr}'
        if fail_on_error:
            raise Exception('subprocess failed\n' + msg)
        else:
            log.warning(msg)
    return proc.returncode == 0
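Illustrative calls (the shell commands themselves are arbitrary):

ok = run_command('ls -l', fail_on_error=False)  # True if the exit code was 0; STDOUT/STDERR are logged
run_command('exit 1')                           # non-zero exit with fail_on_error=True raises an Exception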
Example #21
File: exp.py Project: MGheini/rtg
    def store_model(self,
                    epoch: int,
                    model,
                    train_score: float,
                    val_score: float,
                    keep: int,
                    prefix='model',
                    keeper_sort='step'):
        """
        saves model to a given path
        :param epoch: epoch number of model
        :param model: model object itself
        :param train_score: score of model on training split
        :param val_score: score of model on validation split
        :param keep: number of good models to keep, bad models will be deleted
        :param prefix: prefix to store model. default is "model"
        :param keeper_sort: criteria for choosing the old or bad models for deletion.
            Choices: {'total_score', 'step'}
        :return:
        """
        # TODO: improve this by skipping the model save if the model is not good enough to be saved
        if self.read_only:
            log.warning("Ignoring the store request; experiment is readonly")
            return
        name = f'{prefix}_{epoch:03d}_{train_score:.6f}_{val_score:.6f}.pkl'
        path = self.model_dir / name
        log.info(f"Saving epoch {epoch} to {path}")
        torch.save(model, str(path))

        del_models = []
        if keeper_sort == 'total_score':
            del_models = self.list_models(sort='total_score',
                                          desc=False)[keep:]
        elif keeper_sort == 'step':
            del_models = self.list_models(sort='step', desc=True)[keep:]
        else:
            raise Exception(f'Sort criteria {keeper_sort} not understood')
        for d_model in del_models:
            log.info(
                f"Deleting model {d_model} . Keep={keep}, sort={keeper_sort}")
            os.remove(str(d_model))

        with IO.writer(os.path.join(self.model_dir, 'scores.tsv'),
                       append=True) as f:
            cols = [
                str(epoch),
                datetime.now().isoformat(), name, f'{train_score:g}',
                f'{val_score:g}'
            ]
            f.write('\t'.join(cols) + '\n')
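A hedged usage sketch (exp is assumed to be an initialized Experiment with a model_dir; the state payload and scores are made up):

state = {'model_state': model.state_dict(), 'step': 2000}  # illustrative checkpoint payload
exp.store_model(2000, state, train_score=2.41, val_score=2.87,
                keep=5, keeper_sort='step')
# keeps the 5 newest checkpoints by step, deletes older ones, and appends a row to scores.tsv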
Example #22
def main():
    # No grads required
    torch.set_grad_enabled(False)
    args = parse_args()
    gen_args = {}
    exp = Experiment(args.pop('work_dir'), read_only=True)
    validate_args(args, exp)

    if exp.model_type == 'binmt':
        if not args.get('path'):
            raise Exception('--binmt-path argument is needed for BiNMT model.')
        gen_args['path'] = args.pop('binmt_path')

    weights = args.get('weights')
    if weights:
        decoder = Decoder.combo_new(exp,
                                    model_paths=args.pop('model_path'),
                                    weights=weights)
    else:
        decoder = Decoder.new(exp,
                              gen_args=gen_args,
                              model_paths=args.pop('model_path', None),
                              ensemble=args.pop('ensemble', 1))
    if args.pop('interactive'):
        if weights:
            log.warning(
                "Interactive shell not reloadable for combo mode. FIXME: TODO:"
            )
        if args['input'] != sys.stdin or args['output'] != sys.stdout:
            log.warning(
                '--input and --output args are not applicable in --interactive mode'
            )
        args.pop('input')
        args.pop('output')

        while True:
            try:
                # a hacky way to unload and reload the model when the user switches models
                decoder.decode_interactive(**args)
                break  # exit loop if there is no request for reload
            except ReloadEvent as re:
                decoder = Decoder.new(exp,
                                      gen_args=gen_args,
                                      model_paths=re.model_paths)
                args = re.state
                # go back to loop and redo interactive shell
    else:
        return decoder.decode_file(args.pop('input'), args.pop('output'),
                                   **args)
Example #23
    def train(self, steps: int, check_point: int, batch_size: int,
              check_pt_callback: Optional[Callable] = None, **args):
        train_state = TrainerState(self.model, check_point=check_point)
        train_state.train_mode(True)
        if self.start_step >= steps:
            log.warning(f"Already trained to  {self.start_step}. Considering it as done.")
            return
        rem_steps = steps - self.start_step
        side = 'tgt'     # TODO: this should be inferrable or configurable instead of hardcoded

        train_data = self.exp.get_mono_data('train', side, batch_size=batch_size,
                                            batch_first=True, sort_dec=True,
                                            num_batches=rem_steps, shuffle=True)
        val_data = self.exp.get_mono_data('valid', side, batch_size=batch_size,
                                          batch_first=True, sort_dec=True)

        keep_models = 8
        unsaved_state = False
        with tqdm(train_data, initial=self.start_step, total=steps, unit='batch',
                  dynamic_ncols=True) as data_bar:
            for batch in data_bar:
                batch.to(device)
                outp_log_probs = self.model.batch_forward(batch)
                loss = self.simple_loss_func(outp_log_probs, seq_lens=batch.x_len,
                                             tot_toks=batch.x_toks, max_seq_len=batch.max_x_len,
                                             train_mode=True)
                unsaved_state = True
                bar_msg, is_check_pt = train_state.step(batch.x_toks, loss)
                data_bar.set_postfix_str(bar_msg, refresh=True)
                del batch       # TODO: force free memory
                if is_check_pt:
                    train_loss = train_state.reset()
                    train_state.train_mode(False)
                    val_loss = self.run_valid_epoch(val_data)
                    self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
                    if check_pt_callback:
                        check_pt_callback(model=self.model,
                                          step=self.opt.curr_step,
                                          train_loss=train_loss)
                    train_state.train_mode(True)
                    unsaved_state = False

        log.info("End of training session")
        if unsaved_state:
            # End of training
            train_loss = train_state.reset()
            train_state.train_mode(False)
            val_loss = self.run_valid_epoch(val_data)
            self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
Example #24
def setup_jdbc():
    # add postgres jdbc driver
    import pyspark
    PG_JDBC_URL = 'https://jdbc.postgresql.org/download/postgresql-42.2.14.jar'
    jars_dir = Path(pyspark.__file__).parent / 'jars'
    if jars_dir.exists():
        pg_jdbc_jars = list(jars_dir.glob("postgresql-*.jar"))
        if pg_jdbc_jars:
            log.info(f'Located JDBC jar for postgres: {pg_jdbc_jars}')
        else:
            jar_path = jars_dir / (PG_JDBC_URL.split('/')[-1])
            download_file(PG_JDBC_URL, jar_path, fair_on_error=False)
    else:
        log.warning(
            "pyspark jars are not detected. "
            "You may need to manually configure postgres JDBC to spark config")
Example #25
    def __init__(self, d_model, d_ff, dropout=0.1, activation='relu'):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        activations = dict(relu=F.relu, elu=F.elu)

        if activation == 'gelu':  # TODO: when torch 1.2 comes out, simplify this block
            try:
                activations['gelu'] = F.gelu
            except AttributeError:
                log.warning(
                    f"gelu is not available. using torch {torch.__version__}."
                    f" GELU was added in https://github.com/pytorch/pytorch/pull/20665"
                )
                raise
        self.activation = activations[activation]
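A small instantiation sketch (the dimensions are illustrative; the forward pass is not part of this excerpt, though self.activation is typically applied between w_1 and w_2 with dropout in between):

ff = PositionwiseFeedForward(d_model=512, d_ff=2048, dropout=0.1, activation='relu')
# inputs of shape [batch, seq_len, d_model=512] would be projected to d_ff=2048 and back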
Example #26
 def read_all(self) -> Iterator[IdExample]:
     with IO.reader(self.path) as lines:
         recs = (line.split('\t') for line in lines)
         for idx, rec in enumerate(recs):
             x = self._parse(rec[0].strip())
             y = self._parse(rec[1].strip()) if len(rec) > 1 else None
             if self.truncate:  # truncate long recs
                 x = x[:self.max_src_len]
                 y = y if y is None else y[:self.max_tgt_len]
             elif len(x) > self.max_src_len or (0 if y is None else
                                                len(y)) > self.max_tgt_len:
                 continue  # skip long recs
             if not x or (y is not None
                          and len(y) == 0):  # empty on one side
                 log.warning(
                     f"Ignoring an empty record  x:{len(x)}    y:{len(y)}")
                 continue
             yield IdExample(x, y, id=idx)
Example #27
File: exp.py Project: MGheini/rtg
        def _prep_file(file_key, out_file, do_truncate, max_len, field: Field):
            if file_key not in args:
                log.warning(
                    f'Skipped: {file_key} because it is not found in config')
                return

            raw_file = args[file_key]

            recs = TSVData.read_raw_mono_recs(raw_file, do_truncate, max_len,
                                              field.encode_as_ids)
            # TODO: use SQLite storage
            TSVData.write_mono_recs(recs, out_file)
            if args.get('text_files'):
                recs = TSVData.read_raw_mono_recs(raw_file, do_truncate,
                                                  max_len, field.tokenize)
                TSVData.write_mono_recs(
                    recs,
                    str(out_file).replace('.tsv', '.pieces.tsv'))
Example #28
    def make_check_point(self, train_loss: float, val_loss: float,
                         keep_models: int):
        """
        Check point the model
        :param train_loss: training loss value
        :param val_loss: loss on validation set
        :param keep_models: how many checkpoints to keep on file system
        :return:
        """

        step_num = self.opt.curr_step
        if step_num == self.last_step:
            log.warning("Ignoring checkpt request")
            return  # calling multiple times doesn't save
        log.info(
            f"Checkpoint at step {step_num}. Training Loss {train_loss:g},"
            f" Validation Loss:{val_loss:g}")
        self.show_samples()

        self.tbd.add_scalars(f'losses', {
            'train_loss': train_loss,
            'valid_loss': val_loss
        }, step_num)
        # Unwrap model state from DataParallel and persist
        model = (self.model.module
                 if isinstance(self.model, nn.DataParallel) else self.model)
        state = {
            'model_state': model.state_dict(),
            'optim_state': self.opt.optimizer.state_dict(),
            'step': step_num,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'time': time.time(),
            'rtg_version': rtg.__version__,
            'model_type': self.exp.model_type,
            'model_args': self.exp.model_args,
        }

        self.exp.store_model(step_num,
                             state,
                             train_score=train_loss,
                             val_score=val_loss,
                             keep=keep_models)
        self.last_step = step_num
Example #29
def main():
    args = parse_args()
    conf_file: Path = args.conf_file if args.conf_file else args.work_dir / 'conf.yml'
    assert conf_file.exists()
    ExpFactory = TranslationExperiment
    is_big = load_conf(conf_file).get('spark', {})
    if is_big:
        log.info("Big experiment mode enabled; checking pyspark backend")
        try:
            import pyspark
            log.info("pyspark is available")
        except ImportError:
            log.warning("unable to import pyspark. Please do 'pip install pyspark' and run again")
            raise
        from rtg.big.exp import BigTranslationExperiment
        ExpFactory = BigTranslationExperiment

    exp = ExpFactory(args.exp, config=conf_file, read_only=False)
    return exp.pre_process()
Example #30
    def from_lines(cls,
                   lines: Iterator[str],
                   batch_size: int,
                   vocab: Field,
                   sort=True):
        """
        Note: this changes the order based on sequence length if sort=True
        :param lines: stream of input lines
        :param batch_size: number of tokens in batch
        :param vocab: Field to use for mapping word pieces to ids
        :param sort: sort based on descending order of length
        :return: stream of DecoderBatches
        """
        log.info("Tokenizing sequences")
        buffer = []
        for i, line in enumerate(lines):
            line = line.strip()
            if not line:
                log.warning(
                    f"line {i + 1} was empty. inserting a dot (.). "
                    f"Empty lines are problematic when you want line-by-line alignment..."
                )
                line = "."
            cols = line.split('\t')
            seq = vocab.encode_as_ids(line, add_eos=True, add_bos=False)
            ref = cols[1] if len(cols) > 1 else None
            buffer.append((i, cols[0], ref, seq))  #idx, src, ref, seq

        if sort:
            log.info(f"Sorting based on the length. total = {len(buffer)}")
            buffer = sorted(buffer, reverse=True,
                            key=lambda x: len(x[-1]))  # sort by length of seq

        batch = cls()
        for idx, src, ref, seq in buffer:
            batch.add(idx=idx, src=src, ref=ref, seq=seq)
            if batch.padded_tok_count >= batch_size:
                yield batch
                batch = cls()

        if batch.line_count > 0:
            yield batch