Example #1
    def __init__(self,
                 exp: Experiment,
                 model: Optional[TransformerNMT] = None,
                 optim: str = 'ADAM',
                 model_factory=TransformerNMT.make_model,
                 **optim_args):
        super().__init__(exp,
                         model,
                         model_factory=model_factory,
                         optim=optim,
                         **optim_args)
        trainer_args = self.exp.config.get('trainer', {}).get('init_args', {})
        chunk_size = trainer_args.get('chunk_size', -1)
        self.grad_accum_interval = trainer_args.get('grad_accum', 1)
        assert self.grad_accum_interval > 0

        if self.n_gpus > 1:  # Multi GPU mode
            raise Exception(
                f"Please use: python -m rtg.distrib.launch -G {self.n_gpus} \n "
                f" or set single GPU by: export CUDA_VISIBLE_DEVICES=0 ")

        generator = self.core_model.generator
        if not chunk_size or chunk_size < 1:
            self.loss_func = SimpleLossFunction(generator=generator,
                                                criterion=self.criterion,
                                                opt=self.opt)
        else:
            log.info(f"Using Chunked Loss Generator. chunk_size={chunk_size}")
            self.loss_func = ChunkedLossCompute(generator=generator,
                                                criterion=self.criterion,
                                                opt=self.opt,
                                                chunk_size=chunk_size)
Example #2
File: dummy.py Project: MGheini/rtg
def write_tsv(data, out):
    count = 0
    for src_seq, tgt_seq in data:
        src_seq, tgt_seq = ' '.join(map(str, src_seq)), ' '.join(map(str, tgt_seq))
        out.write(f'{src_seq}\t{tgt_seq}\n')
        count += 1
    log.info(f"Wrote {count} records")
Example #3
 def _make_vocab(self,
                 name: str,
                 vocab_file: Path,
                 model_type: str,
                 vocab_size: int,
                 corpus: List,
                 no_split_toks: List[str] = None,
                 char_coverage=0) -> Field:
     if vocab_file.exists():
         log.info(
             f"{vocab_file} exists. Skipping the {name} vocab creation")
         return self.Field(str(vocab_file))
     with log_resources(f"create vocab {name}"):
         flat_uniq_corpus = set()  # remove dupes, flatten the nested lists or sets
         for i in corpus:
             if isinstance(i, set) or isinstance(i, list):
                 flat_uniq_corpus.update(i)
             else:
                 flat_uniq_corpus.add(i)
         with spark_session(config=self.spark_conf) as spark:
             flat_uniq_corpus = list(flat_uniq_corpus)
             log.info(
                 f"Going to build {name} vocab from {len(flat_uniq_corpus)} files "
             )
             return self.Field.train(model_type,
                                     vocab_size,
                                     str(vocab_file),
                                     flat_uniq_corpus,
                                     no_split_toks=no_split_toks,
                                     char_coverage=char_coverage,
                                     spark=spark)
Example #4
    def make_model(cls,
                   src_vocab,
                   tgt_vocab,
                   n_layers=6,
                   hid_size=512,
                   ff_size=2048,
                   n_heads=8,
                   attn_dropout=0.1,
                   dropout=0.1,
                   activation='relu',
                   tied_emb='three-way',
                   plug_mode='cat_attn',
                   exp: Experiment = None):
        """
        Helper: Construct a model from hyperparameters.
        :return: model, args
        """
        assert plug_mode in {'cat_attn', 'add_attn', 'cat_emb'}
        # get all args for reconstruction at a later phase
        _, _, _, args = inspect.getargvalues(inspect.currentframe())
        for exclusion in ['cls', 'exp']:
            del args[exclusion]  # exclude some args
        # In case you are wondering why I didn't use **kwargs here:
        #   these args are read from the conf file where a user can introduce errors, so parameter
        #   validation and default value assignment are implicitly done by the function call for us :)
        log.info(f"making mtfmnmt model: {args}")
        c = copy.deepcopy
        attn = MultiHeadedAttention(n_heads, hid_size, dropout=attn_dropout)
        ff = PositionwiseFeedForward(hid_size,
                                     ff_size,
                                     dropout,
                                     activation=activation)

        enc_layer = EncoderLayer(hid_size, c(attn), c(ff), dropout)
        encoder = Encoder(enc_layer, n_layers)  # clones n times
        src_emb = nn.Sequential(Embeddings(hid_size, src_vocab),
                                PositionalEncoding(hid_size, dropout))

        if plug_mode == 'cat_emb':
            tgt_emb = nn.Sequential(MEmbeddings(hid_size, tgt_vocab),
                                    PositionalEncoding(hid_size, dropout))
            decoder = c(encoder)  # decoder is same as encoder, except embeddings have concat
        else:
            dec_block = DecoderBlock(hid_size, dropout, mode=plug_mode)
            dec_layer = MDecoderLayer(hid_size, c(attn), c(dec_block), c(ff),
                                      dropout)
            decoder = MDecoder(dec_layer, n_layers)
            tgt_emb = nn.Sequential(Embeddings(hid_size, tgt_vocab),
                                    PositionalEncoding(hid_size, dropout))

        generator = Generator(hid_size, tgt_vocab)

        model = cls(encoder, decoder, src_emb, tgt_emb, generator)
        if tied_emb:
            model.tie_embeddings(tied_emb)

        model.init_params()
        return model, args
Example #5
    def __init__(self,
                 models: List[Path],
                 exp: Union[Path, TranslationExperiment],
                 lr: float = 1e-4,
                 smoothing=0.1):
        if isinstance(exp, Path):
            exp = TranslationExperiment(exp)
        self.w_file = exp.work_dir / f'combo-weights.yml'

        wt = None
        if self.w_file.exists():
            with IO.reader(self.w_file) as rdr:
                combo_spec = yaml.load(rdr)
            weights = combo_spec['weights']
            assert len(weights) == len(models)  # same models as before: no messing allowed
            model_path_strs = [str(m) for m in models]
            for m in model_path_strs:
                assert m in weights, f'{m} not found in weights file.'
            wt = [weights[str(m)] for m in model_path_strs]
            log.info(f"restoring previously stored weights {wt}")

        from rtg.module.decoder import load_models
        combo = Combo(load_models(models, exp), model_paths=models, w=wt)
        self.combo = combo.to(device)
        self.exp = exp
        self.optim = torch.optim.Adam(combo.parameters(), lr=lr)
        self.criterion = LabelSmoothing(vocab_size=combo.vocab_size,
                                        padding_idx=PAD_TOK_IDX,
                                        smoothing=smoothing)
Example #6
 def map_rows(cls, src, dest, mapping):
     skips = 0
     for dest_idx, src_idx in mapping.items():
         if dest_idx < 0 or src_idx < 0:
             skips += 1
             continue  # skip unmapped rows; a negative index would silently grab the wrong row
         dest[dest_idx] = src[src_idx]
     log.info(f"Mapped rows. Total skips = {skips}")
Example #7
def read_bitext(spark,
                src_file: Union[str, Path],
                tgt_file: Union[str, Path],
                src_name='src_raw',
                tgt_name='tgt_raw') -> Tuple[DataFrame, int]:
    if not isinstance(src_file, str):
        src_file = str(src_file)
    if not isinstance(tgt_file, str):
        tgt_file = str(tgt_file)

    src_df = spark.read.text(src_file).withColumnRenamed('value', src_name)
    tgt_df = spark.read.text(tgt_file).withColumnRenamed('value', tgt_name)

    n_src, n_tgt = src_df.count(), tgt_df.count()
    assert n_src == n_tgt, f'{n_src} == {n_tgt} ?'
    log.info(f"Found {n_src:,} parallel records in {src_file, tgt_file}")

    def with_idx(sdf):
        new_schema = StructType(sdf.schema.fields + [
            StructField("idx", LongType(), False),
        ])
        return (sdf.rdd.zipWithIndex()
                .map(lambda row: row[0] + (row[1],))
                .toDF(schema=new_schema))

    src_df = with_idx(src_df)
    tgt_df = with_idx(tgt_df)
    bitext_df = src_df.join(tgt_df, 'idx', "inner")
    # n_bitext = bitext_df.count()
    # assert n_bitext == n_src, f'{n_bitext} == {n_src} ??'
    return bitext_df, n_src
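
A hedged usage sketch for read_bitext: it assumes an active SparkSession named spark and two line-aligned text files; the file names below are made up.

# 'train.src' / 'train.tgt' are illustrative paths; spark is an existing SparkSession
bitext_df, n_rows = read_bitext(spark, 'train.src', 'train.tgt')
bitext_df.select('idx', 'src_raw', 'tgt_raw').show(5, truncate=False)
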
Example #8
def main():
    args = parse_args()
    seed = args.pop("seed")
    if seed:
        log.info(f"Seed for random number generator: {seed}")
        import random
        import torch
        random.seed(seed)
        torch.manual_seed(seed)

    work_dir = Path(args.pop('work_dir'))
    is_big = load_conf(work_dir / 'conf.yml').get('spark', {})

    if is_big:
        log.info("Big experiment mode enabled; checking pyspark backend")
        try:
            import pyspark
    except ImportError:
        log.warning("unable to import pyspark. Please do 'pip install pyspark' and run again")
            raise
        from rtg.big.exp import BigTranslationExperiment
        exp = BigTranslationExperiment(work_dir=work_dir)
    else:
        exp = Experiment(work_dir=work_dir)
    assert exp.has_prepared(), f'Experiment dir {exp.work_dir} is not ready to train. ' \
                               f'Please run "prep" sub task'
    exp.train(args)
Example #9
    def write_df(self, df, table_name: str, mode="overwrite"):
        log.info(f"writing dataframe to {self.url}; table={table_name} mode={mode}")
        return (df.write.mode(mode)
                .format("jdbc")
                .option("url", self.url)
                .option("dbtable", table_name)
                .option("driver", "org.postgresql.Driver")
                .save())
Example #10
File: exp.py Project: MGheini/rtg
    def _make_vocab(self,
                    name: str,
                    vocab_file: Path,
                    model_type: str,
                    vocab_size: int,
                    corpus: List,
                    no_split_toks: List[str] = None) -> Field:
        """
        Construct vocabulary file
        :param name: src, tgt or shared -- for the sake of logging
        :param vocab_file: where to save the vocab file
        :param model_type: sentence piece model type
        :param vocab_size: max types in vocab
        :param corpus: as the name says, list of files from which the vocab should be learned
        :param no_split_toks: tokens that need to be preserved from splitting, or added
        :return:
        """
        if vocab_file.exists():
            log.info(
                f"{vocab_file} exists. Skipping the {name} vocab creation")
            return Field(str(vocab_file))
        flat_uniq_corpus = set()  # remove dupes, flatten the nested lists or sets
        for i in corpus:
            if isinstance(i, set) or isinstance(i, list):
                flat_uniq_corpus.update(i)
            else:
                flat_uniq_corpus.add(i)

        log.info(f"Going to build {name} vocab from mono files")
        return Field.train(model_type,
                           vocab_size,
                           str(vocab_file),
                           flat_uniq_corpus,
                           no_split_toks=no_split_toks)
Example #11
    def make_model(cls, src_vocab, tgt_vocab, enc_layers=6, hid_size=512, ff_size=2048, enc_heads=8,
                   dropout=0.1, tied_emb='three-way', dec_rnn_type: str = 'LSTM',
                   dec_layers: int = 1,
                   exp: Experiment = None):
        """
        Helper: Construct a model from hyperparameters.
        :return: model, args
        """
        # get all args for reconstruction at a later phase
        _, _, _, args = inspect.getargvalues(inspect.currentframe())
        for exclusion in ['cls', 'exp']:
            del args[exclusion]  # exclude some args

        log.info(f"making hybridmt model: {args}")

        c = copy.deepcopy
        attn = MultiHeadedAttention(enc_heads, hid_size)
        ff = PositionwiseFeedForward(hid_size, ff_size, dropout)

        enc_layer = EncoderLayer(hid_size, c(attn), c(ff), dropout)
        encoder = Encoder(enc_layer, enc_layers)  # clones n times
        src_emb = nn.Sequential(Embeddings(hid_size, src_vocab),
                                PositionalEncoding(hid_size, dropout))

        decoder = RnnDecoder(rnn_type=dec_rnn_type, hid_size=hid_size, n_layers=dec_layers,
                             dropout=dropout)
        tgt_emb = Embeddings(hid_size, tgt_vocab)
        generator = Generator(hid_size, tgt_vocab)

        model = cls(encoder, decoder, src_emb, tgt_emb, generator)
        if tied_emb:
            model.tie_embeddings(tied_emb)

        model.init_params()
        return model, args
Example #12
    def train(self,
              steps: int,
              check_point: int,
              batch_size: int,
              check_pt_callback: Optional[Callable] = None,
              side='tgt',
              ctx_size=2,
              **args):
        log.info(f"using side={side}, ctx_size={ctx_size}")
        reader = DataReader(self.exp, side=side)
        rem_steps = steps - self.start_step
        if rem_steps <= 0:
            log.warning(f"Already trained upto {self.start_step-1}. Skipped")
            return
        train_data = reader.get_training_data(batch_size=batch_size,
                                              n_batches=rem_steps,
                                              ctx_size=ctx_size)
        val_data = reader.get_val_data(batch_size=batch_size,
                                       ctx_size=ctx_size)
        train_loss, n = 0.0, 0

        def _make_checkpt(step, train_loss):
            with torch.no_grad():
                val_loss = self.run_valid_epoch(val_data)
            log.info(f"Checkpoint at {step}")
            self.save_embeddings(step, train_loss, val_loss, txt=True)
            self.tbd.add_scalars('losses',
                                 {'training': train_loss, 'valid_loss': val_loss},
                                 global_step=step)

        with tqdm(train_data,
                  initial=self.start_step,
                  total=rem_steps + 1,
                  unit='batch',
                  dynamic_ncols=True) as data_bar:
            for i, (xs, ys) in enumerate(data_bar, start=self.start_step):
                self.model.zero_grad()
                xs, ys = xs.to(device), ys.to(device)
                log_probs = self.model(xs)
                loss = self.loss_func(log_probs, ys)
                self.tbd.add_scalars('training', {
                    'step_loss': loss.item(),
                    'learn_rate': self.opt.curr_lr
                }, self.opt.curr_step)
                progress_msg = f', loss={loss:g} LR={self.opt.curr_lr:g}'
                data_bar.set_postfix_str(progress_msg, refresh=False)
                train_loss += loss.item()
                n += len(ys)

                loss.backward()
                self.opt.step()
                self.opt.zero_grad()

                if i > 0 and i % check_point == 0:
                    _make_checkpt(i, train_loss / n)
                    train_loss, n = 0.0, 0
        if n > 0:
            _make_checkpt(steps, train_loss / n)
Example #13
 def _make_checkpt(step, train_loss):
     with torch.no_grad():
         val_loss = self.run_valid_epoch(val_data)
     log.info(f"Checkpoint at {step}")
     self.save_embeddings(step, train_loss, val_loss, txt=True)
     self.tbd.add_scalars('losses', {'training': train_loss,
                                     'valid_loss': val_loss}, global_step=step)
Example #14
 def enable_fp16(self):
     if not self.fp16:  # conditional import
         self.fp16 = True
         self._scaler = GradScaler(enabled=self.fp16)
         log.info("Enabling FP16  /Automatic Mixed Precision training")
     else:
         log.warning(" fp16 is already enabled")
Example #15
    def __init__(self,
                 data_path: Union[str, Path],
                 batch_size: int,
                 sort_desc: bool = False,
                 batch_first: bool = True,
                 shuffle: bool = False,
                 sort_by: str = None,
                 **kwargs):
        """
        Iterator for reading training data in batches
        :param data_path: path to TSV file
        :param batch_size: number of tokens on the target side per batch
        :param sort_desc: should the batch be sorted by src sequence len (useful for RNN api)
        """
        self.sort_desc = sort_desc
        self.batch_size = batch_size
        self.batch_first = batch_first
        self.sort_by = sort_by
        if not isinstance(data_path, Path):
            data_path = Path(data_path)
        if data_path.name.endswith(".db"):
            self.data = SqliteFile(data_path, sort_by=sort_by, **kwargs)
        else:
            if sort_by:
                raise Exception(
                    f'sort_by={sort_by} not supported for TSV data')
            self.data = TSVData(data_path,
                                shuffle=shuffle,
                                longest_first=False,
                                **kwargs)
        log.info(f'Batch Size = {batch_size} toks, sort_by={sort_by}')
Example #16
    def __init__(self,
                 path: Path,
                 sort_by='random',
                 len_rand=2,
                 max_src_len: int = 512,
                 max_tgt_len: int = 512,
                 truncate: bool = False):

        log.info(f"{type(self)} Args: {get_my_args()}")
        self.path = path
        assert path.exists()
        self.select_qry = self.make_query(sort_by, len_rand=len_rand)
        self.max_src_len, self.max_tgt_len = max_src_len, max_tgt_len
        self.truncate = truncate
        self.db = sqlite3.connect(str(path))

        def dict_factory(cursor, row):  # map tuples to dictionary with column names
            d = {}
            for idx, col in enumerate(cursor.description):
                key = col[0]
                val = row[idx]
                if key in ('x', 'y') and val is not None:
                    val = pickle.loads(val)  # unmarshall
                d[key] = val
            return d

        self.db.row_factory = dict_factory
Example #17
    def __init__(self,
                 path: Path,
                 sort_by='random',
                 len_rand=2,
                 max_src_len: int = 512,
                 max_tgt_len: int = 512,
                 truncate: bool = False):

        log.info(f"{type(self)} Args: {get_my_args()}")
        self.path = path
        assert path.exists()
        self.select_qry = self.make_query(sort_by, len_rand=len_rand)
        self.max_src_len, self.max_tgt_len = max_src_len, max_tgt_len
        self.truncate = truncate
        self.db = sqlite3.connect(str(path))
        self.db_version = self.db.execute('PRAGMA user_version;').fetchone()[0]

        def dict_factory(cursor, row):  # map tuples to dictionary with column names
            d = {}
            for idx, col in enumerate(cursor.description):
                key = col[0]
                val = row[idx]
                if key in ('x', 'y') and val is not None:
                    if self.db_version < 1:
                        val = pickle.loads(val)  # unmarshall
                        val = np.array(val, dtype=np.int32)
                    else:  # version 1 and above
                        val = np.frombuffer(val, dtype=np.int32)
                d[key] = val
            return d

        self.db.row_factory = dict_factory
Example #18
File: exp.py Project: MGheini/rtg
    def get_train_data(self,
                       batch_size: int,
                       steps: int = 0,
                       sort_by='eq_len_rand_batch',
                       batch_first=True,
                       shuffle=False,
                       fine_tune=False):
        inp_file = self.train_db if self.train_db.exists() else self.train_file
        if fine_tune:
            if not self.finetune_file.exists():
                # user may have added fine tune file later
                self._pre_process_parallel('finetune_src', 'finetune_tgt',
                                           self.finetune_file)
            log.info("Using Fine tuning corpus instead of training corpus")
            inp_file = self.finetune_file

        train_data = BatchIterable(inp_file,
                                   batch_size=batch_size,
                                   sort_by=sort_by,
                                   batch_first=batch_first,
                                   shuffle=shuffle,
                                   **self._get_batch_args())
        if steps > 0:
            train_data = LoopingIterable(train_data, steps)
        return train_data
Example #19
    def write(cls, path, records: Iterator[ParallelSeqRecord]):
        if path.exists():
            log.warning(f"Overwriting {path} with new records")
            os.remove(str(path))
        maybe_tmp = IO.maybe_tmpfs(path)
        log.info(f'Creating {maybe_tmp}')
        conn = sqlite3.connect(str(maybe_tmp))
        cur = conn.cursor()
        cur.execute(cls.TABLE_STATEMENT)
        cur.execute(cls.INDEX_X_LEN)
        cur.execute(cls.INDEX_Y_LEN)
        cur.execute(f"PRAGMA user_version = {cls.CUR_VERSION};")

        count = 0
        for x_seq, y_seq in records:
            # use numpy; it's a lot more efficient
            if not isinstance(x_seq, np.ndarray):
                x_seq = np.array(x_seq, dtype=np.int32)
            if y_seq is not None and not isinstance(y_seq, np.ndarray):
                y_seq = np.array(y_seq, dtype=np.int32)
            values = (x_seq.tobytes(),
                      None if y_seq is None else y_seq.tobytes(), len(x_seq),
                      len(y_seq) if y_seq is not None else -1)
            cur.execute(cls.INSERT_STMT, values)
            count += 1
        cur.close()
        conn.commit()
        if maybe_tmp != path:
            # bring the file back to original location where it should be
            IO.copy_file(maybe_tmp, path)
        log.info(f"stored {count} rows in {path}")
Example #20
    def make_model(cls,
                   src_vocab,
                   tgt_vocab,
                   enc_layers=6,
                   dec_layers=6,
                   hid_size=512,
                   ff_size=2048,
                   n_heads=8,
                   attn_bias=True,
                   attn_dropout=0.1,
                   dropout=0.2,
                   activation='relu',
                   enc_depth_probs: List[float] = (1.0, 0.9, 0.8, 0.7, 0.6,
                                                   0.5),
                   dec_depth_probs: List[float] = (1.0, 0.9, 0.8, 0.7, 0.6,
                                                   0.5),
                   tied_emb='three-way',
                   exp: Experiment = None):
        """Helper: Construct a model from hyper parameters."""
        assert len(enc_depth_probs) == enc_layers
        assert len(dec_depth_probs) == dec_layers
        # get all args for reconstruction at a later phase
        args = get_my_args(exclusions=['cls', 'exp'])

        assert activation in {'relu', 'elu', 'gelu'}
        log.info(f"Make model, Args={args}")
        c = copy.deepcopy
        attn = tfm.MultiHeadedAttention(n_heads,
                                        hid_size,
                                        dropout=attn_dropout,
                                        bias=attn_bias)
        ff = tfm.PositionwiseFeedForward(hid_size,
                                         ff_size,
                                         dropout,
                                         activation=activation)

        if enc_layers == 0:
            log.warning("Zero encoder layers!")
        encoder = SkipEncoder(
            tfm.EncoderLayer(hid_size, c(attn), c(ff), dropout), enc_layers,
            enc_depth_probs)

        assert dec_layers > 0
        decoder = SkipDecoder(
            tfm.DecoderLayer(hid_size, c(attn), c(attn), c(ff), dropout),
            dec_layers, dec_depth_probs)

        src_emb = nn.Sequential(tfm.Embeddings(hid_size, src_vocab),
                                tfm.PositionalEncoding(hid_size, dropout))
        tgt_emb = nn.Sequential(tfm.Embeddings(hid_size, tgt_vocab),
                                tfm.PositionalEncoding(hid_size, dropout))
        generator = tfm.Generator(hid_size, tgt_vocab)

        model = cls(encoder, decoder, src_emb, tgt_emb, generator)

        if tied_emb:
            model.tie_embeddings(tied_emb)

        model.init_params()
        return model, args
Example #21
    def new(cls,
            exp: Experiment,
            model=None,
            gen_args=None,
            model_paths: Optional[List[str]] = None,
            ensemble: int = 1,
            model_type: Optional[str] = None):
        """
        create a new decoder
        :param exp: experiment
        :param model: optional pre-initialized model
        :param gen_args: any optional args needed for generator
        :param model_paths: optional model paths
        :param ensemble: number of models to use for ensembling (if model is not specified)
        :param model_type: model type; when not specified, it is read from the experiment
        :return:
        """
        if not model_type:
            model_type = exp.model_type
        if model is None:
            factory = factories[model_type]
            model = factory(exp=exp, **exp.model_args)[0]
            state = cls.maybe_ensemble_state(exp,
                                             model_paths=model_paths,
                                             ensemble=ensemble)
            model.load_state_dict(state)
            log.info("Successfully restored the model state.")
        elif isinstance(model, nn.DataParallel):
            model = model.module

        model = model.eval().to(device=device)
        generator = generators[model_type]
        return cls(model, generator, exp, gen_args)
Example #22
File: tfmlm.py Project: MGheini/rtg
    def make_model(cls,
                   vocab_size,
                   n_layers=6,
                   hid_size=512,
                   ff_size=2048,
                   n_heads=8,
                   dropout=0.1,
                   tied_emb=True,
                   exp: Experiment = None):
        # get all args for reconstruction at a later phase
        _, _, _, args = inspect.getargvalues(inspect.currentframe())
        for exclusion in ['cls', 'exp']:
            del args[exclusion]  # exclude some args
        # In case you are wondering why I didn't use **kwargs here:
        #   these args are read from the conf file where a user can introduce errors, so parameter
        #   validation and default value assignment are implicitly done by the function call for us :)

        c = copy.deepcopy
        attn = MultiHeadedAttention(n_heads, hid_size)
        ff = PositionwiseFeedForward(hid_size, ff_size, dropout)
        dec_layer = LMDecoderLayer(hid_size, c(attn), c(ff), dropout)
        decoder = LMDecoder(dec_layer, n_layers)
        embedr = nn.Sequential(Embeddings(hid_size, vocab_size),
                               PositionalEncoding(hid_size, dropout))
        generator = Generator(hid_size, vocab_size)

        model = TfmLm(decoder, embedr, generator)
        if tied_emb:
            log.info(
                "Tying the embedding weights, two ways: (TgtIn == TgtOut)")
            model.generator.proj.weight = model.embed[0].lut.weight

        model.init_params()
        return model, args
Example #23
 def maybe_ensemble_state(exp,
                          model_paths: Optional[List[str]],
                          ensemble: int = 1):
     if model_paths and len(model_paths) == 1:
         log.info(f" Restoring state from requested model {model_paths[0]}")
         return Decoder._checkpt_to_model_state(model_paths[0])
     elif not model_paths and ensemble <= 1:
         model_path, _ = exp.get_best_known_model()
         log.info(f" Restoring state from best known model: {model_path}")
         return Decoder._checkpt_to_model_state(model_path)
     else:
         if not model_paths:
             # Average
             model_paths = exp.list_models()[:ensemble]
         digest = hashlib.md5(";".join(str(p) for p in model_paths).encode('utf-8')).hexdigest()
         cache_file = exp.model_dir / f'avg_state{len(model_paths)}_{digest}.pkl'
         lock_file = cache_file.with_suffix('.lock')
         MAX_TIMEOUT = 12 * 60 * 60  # 12 hours
         with portalocker.Lock(lock_file, 'w', timeout=MAX_TIMEOUT) as fh:
             # check if already created by another parallel process
             if lock_file.exists() and cache_file.exists():
                 log.info(f"Cache exists: reading from {cache_file}")
                 state = Decoder._checkpt_to_model_state(cache_file)
             else:
                 log.info(
                     f"Averaging {len(model_paths)} model states :: {model_paths}"
                 )
                 state = Decoder.average_states(model_paths)
                 if len(model_paths) > 1:
                     log.info(f"Caching the averaged state at {cache_file}")
                     torch.save(state, str(cache_file))
         return state
Example #24
 def init_src_embedding(self, weights):
     log.info("Initializing source embeddings")
     log.info(f"Embedding matrix object ids: "
              f" src_inp: {id(self.src_embed[0].lut.weight.data)}"
              f" tgt_inp: {id(self.tgt_embed[0].lut.weight.data)} "
              f" tgt_out: {id(self.generator.proj.weight.data)}")
     assert weights.shape == self.src_embed[0].lut.weight.shape
     self.src_embed[0].lut.weight.data.copy_(weights.data)
Example #25
 def new(self, parameters, lr=0.001, **args):
     log.info(
         f"Creating {self.value} optimizer with lr={lr} and extra args:{args}"
     )
     log.info(
         f"   {self.value}, default arguments {inspect.signature(self.value)}"
     )
     return self.value(parameters, lr=lr, **args)
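
A hedged usage sketch for new(): it assumes this method lives on an optimizer Enum whose members wrap torch optimizer classes; the Optims.ADAM member name and the model variable are assumptions here, not shown in this excerpt.

# Assumes Optims.ADAM.value is torch.optim.Adam and model is an existing nn.Module
optimizer = Optims.ADAM.new(model.parameters(), lr=5e-4, betas=(0.9, 0.98))
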
Example #26
def __test_seq2seq_model__():
    """
        batch_size = 4
        p = '/Users/tg/work/me/rtg/saral/runs/1S-rnn-basic'
        exp = Experiment(p)
        steps = 3000
        check_pt = 100
        trainer = SteppedRNNNMTTrainer(exp=exp, lr=0.01, warmup_steps=100)
        trainer.train(steps=steps, check_point=check_pt, batch_size=batch_size)
    """
    from rtg.dummy import DummyExperiment
    from rtg.module.decoder import Decoder

    vocab_size = 50
    batch_size = 30
    exp = DummyExperiment("tmp.work",
                          config={'model_type': 'seq2seq'},
                          read_only=True,
                          vocab_size=vocab_size)
    emb_size = 100
    model_dim = 100
    steps = 3000
    check_pt = 100

    assert 2 == Batch.bos_val
    src = tensor([[4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
                  [13, 12, 11, 10, 9, 8, 7, 6, 5, 4]])
    src_lens = tensor([src.size(1)] * src.size(0))

    for reverse in (False, ):
        # train two models;
        #  first, just copy the numbers, i.e. y = x
        #  second, reverse the numbers y=(V + reserved - x)
        log.info(f"====== REVERSE={reverse}; VOCAB={vocab_size}======")
        model, args = RNNMT.make_model('DummyA',
                                       'DummyB',
                                       vocab_size,
                                       vocab_size,
                                       attention='dot',
                                       emb_size=emb_size,
                                       hid_size=model_dim,
                                       n_layers=1)
        trainer = SteppedRNNMTTrainer(exp=exp,
                                      model=model,
                                      lr=0.01,
                                      warmup_steps=100)
        decr = Decoder.new(exp, model)

        def check_pt_callback(**args):
            res = decr.greedy_decode(src, src_lens, max_len=17)
            for score, seq in res:
                log.info(f'{score:.4f} :: {seq}')

        trainer.train(steps=steps,
                      check_point=check_pt,
                      batch_size=batch_size,
                      check_pt_callback=check_pt_callback)
Example #27
    def make_model(
            cls,
            src_vocab,
            tgt_vocab,
            enc_layers=6,
            dec_layers=6,
            hid_size=512,
            n_heads=8,
            attn_dropout=0.1,
            dropout=0.2,
            activation='relu',
            eff_dims: List[int] = (1024, 1024, 2048, 2048, 1024,
                                   1024),  # Using tuple for immutability
            dff_dims: List[int] = (1024, 1024, 2048, 2048, 1024, 1024),
            tied_emb='three-way',
            exp: Experiment = None):
        """Helper: Construct a model from hyper parameters."""

        assert enc_layers == len(eff_dims)
        assert dec_layers == len(dff_dims)

        # get all args for reconstruction at a later phase
        args = get_my_args(exclusions=['cls', 'exp'])
        assert activation in {'relu', 'elu', 'gelu'}
        log.info(f"Make model, Args={args}")

        if enc_layers == 0:
            log.warning("Zero encoder layers!")
        encoder = WidthVaryingEncoder(d_model=hid_size,
                                      ff_dims=eff_dims,
                                      N=enc_layers,
                                      n_heads=n_heads,
                                      attn_dropout=attn_dropout,
                                      dropout=dropout,
                                      activation=activation)

        assert dec_layers > 0
        decoder = WidthVaryingDecoder(d_model=hid_size,
                                      ff_dims=dff_dims,
                                      N=dec_layers,
                                      n_heads=n_heads,
                                      attn_dropout=attn_dropout,
                                      dropout=dropout,
                                      activation=activation)

        src_emb = nn.Sequential(Embeddings(hid_size, src_vocab),
                                PositionalEncoding(hid_size, dropout))
        tgt_emb = nn.Sequential(Embeddings(hid_size, tgt_vocab),
                                PositionalEncoding(hid_size, dropout))
        generator = Generator(hid_size, tgt_vocab)

        model = cls(encoder, decoder, src_emb, tgt_emb, generator)

        if tied_emb:
            model.tie_embeddings(tied_emb)

        model.init_params()
        return model, args
Example #28
    def from_lines(cls,
                   lines: Iterator[str],
                   batch_size: int,
                   vocab: Field,
                   sort=True,
                   max_src_len=0,
                   max_len_buffer=0):
        """
        Note: this changes the order based on sequence length if sort=True
        :param lines: stream of input lines
        :param batch_size: number of tokens in batch
        :param vocab: Field to use for mapping word pieces to ids
        :param sort: sort based on descending order of length
        :param max_src_len: truncate at this length; 0 disables truncation
        :return: stream of DecoderBatches
        """
        log.info("Tokenizing sequences")
        buffer = []
        for i, line in enumerate(lines):
            line = line.strip()
            if not line:
                log.warning(
                    f"line {i + 1} was empty. inserting a dot (.). "
                    f"Empty lines are problematic when you want line-by-line alignment..."
                )
                line = "."
            cols = line.split('\t')
            id, ref = None, None
            if len(cols) == 1:  # SRC
                src = cols[0]
            elif len(cols) == 2:  # ID \t SRC
                id, src = cols
            else:  # ID \t SRC \t REF
                id, src, ref = cols[:3]
            seq = vocab.encode_as_ids(src, add_eos=True, add_bos=False)
            if max_src_len > 0 and len(seq) > max_src_len:
                log.warning(
                    f"Line {i} full length={len(seq)} ; truncated to {max_src_len}"
                )
                seq = seq[:max_src_len]
            buffer.append((i, src, ref, seq, id))  # idx, src, ref, seq, id

        if sort:
            log.info(f"Sorting based on the length. total = {len(buffer)}")
            buffer = sorted(buffer, reverse=True,
                            key=lambda x: len(x[3]))  # sort by length of seq

        batch = cls()
        batch.max_len_buffer = max_len_buffer
        for idx, src, ref, seq, _id in buffer:
            batch.add(idx=idx, src=src, ref=ref, seq=seq, id=_id)
            if batch.padded_tok_count >= batch_size:
                yield batch
                batch = cls()
                batch.max_len_buffer = max_len_buffer

        if batch.line_count > 0:
            yield batch
Example #29
def load_models(models: List[Path], exp: Experiment):
    res = []
    for i, model_path in enumerate(models):
        assert model_path.exists()
        log.info(f"Load Model {i}: {model_path} ")
        chkpt = torch.load(str(model_path), map_location=device)
        model = instantiate_model(chkpt)
        res.append(model)
    return res
Example #30
 def start(self):
     if self.init_flag.exists():
         log.info("Going to start db")
         run_command(self.start_cmd, fail_on_error=True)
     else:
         log.info("Going to set up db")
         for cmd in self.setup_seq:
             run_command(cmd_line=cmd, fail_on_error=True)
         self.init_flag.touch()