Example #1
    @classmethod
    def write(cls, path: Path, records: Iterator[ParallelSeqRecord]):
        if path.exists():
            log.warning(f"Overwriting {path} with new records")
            os.remove(str(path))
        maybe_tmp = IO.maybe_tmpfs(path)
        log.info(f'Creating {maybe_tmp}')
        conn = sqlite3.connect(str(maybe_tmp))
        cur = conn.cursor()
        cur.execute(cls.TABLE_STATEMENT)
        cur.execute(cls.INDEX_X_LEN)
        cur.execute(cls.INDEX_Y_LEN)
        cur.execute(f"PRAGMA user_version = {cls.CUR_VERSION};")

        count = 0
        for x_seq, y_seq in records:
            # use numpy; it's a lot more efficient
            if not isinstance(x_seq, np.ndarray):
                x_seq = np.array(x_seq, dtype=np.int32)
            if y_seq is not None and not isinstance(y_seq, np.ndarray):
                y_seq = np.array(y_seq, dtype=np.int32)
            values = (x_seq.tobytes(),
                      None if y_seq is None else y_seq.tobytes(), len(x_seq),
                      len(y_seq) if y_seq is not None else -1)
            cur.execute(cls.INSERT_STMT, values)
            count += 1
        cur.close()
        conn.commit()
        conn.close()  # close before copying the file out of the tmp location
        if maybe_tmp != path:
            # bring the file back to original location where it should be
            IO.copy_file(maybe_tmp, path)
        log.info(f"stored {count} rows in {path}")
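
A hedged usage sketch for the writer above. The host class name (SqliteFile here) is an assumption; the snippet only shows that write() is a classmethod taking a target path and an iterator of (x_seq, y_seq) pairs where y_seq may be None:

from pathlib import Path
import numpy as np

# two token-id records; the second has no target side
records = iter([
    (np.array([1, 2, 3], dtype=np.int32), np.array([4, 5], dtype=np.int32)),
    ([6, 7, 8], None),  # plain lists are converted to numpy inside write()
])
SqliteFile.write(Path('data/train.db'), records)  # SqliteFile is hypothetical
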
Example #2
def fork_experiment(from_exp: Path, to_exp: Path, conf: bool, vocab: bool, data: bool, code: bool):
    assert from_exp.exists()
    log.info(f'Fork: {from_exp} → {to_exp}')
    if not to_exp.exists():
        log.info(f"Create dir {to_exp}")
        to_exp.mkdir(parents=True)
    if conf:
        conf_file = to_exp / 'conf.yml'
        IO.maybe_backup(conf_file)
        IO.copy_file(from_exp / 'conf.yml', conf_file)
    if data:
        to_data_dir = (to_exp / 'data')
        from_data_dir = from_exp / 'data'
        if to_data_dir.is_symlink():
            log.info(f"removing the existing data link: {to_data_dir.resolve()}")
            to_data_dir.unlink()
        assert not to_data_dir.exists()
        assert from_data_dir.exists()
        log.info(f"link {to_data_dir} → {from_data_dir}")
        to_data_dir.symlink_to(from_data_dir.resolve())
        (to_exp / '_PREPARED').touch(exist_ok=True)
    if not data and vocab: # just the vocab
        Experiment(from_exp, read_only=True).copy_vocabs(
            Experiment(to_exp, config={'Not': 'Empty'}, read_only=True))

    if code:
        for f in ['rtg.zip', 'githead']:
            src = from_exp / f
            if not src.exists():
                log.warning(f"File Not Found: {src}")
                continue
            IO.copy_file(src, to_exp / f)
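
A minimal invocation sketch with placeholder paths. Note from the code above that when data=True the data directory is symlinked from the source experiment, and vocabularies are copied separately only in the `not data and vocab` branch:

from pathlib import Path

fork_experiment(Path('runs/base-exp'), Path('runs/forked-exp'),
                conf=True, vocab=False, data=True, code=True)
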
Example #3
    def copy_vocabs(self, other):
        """
        Copies vocabulary files from self to other
        :param other: other experiment
        """
        other: TranslationExperiment = other
        if not other.data_dir.exists():
            other.data_dir.mkdir(parents=True)
        for source, destination in [
            (self._src_field_file, other._src_field_file),
            (self._tgt_field_file, other._tgt_field_file),
            (self._shared_field_file, other._shared_field_file)
        ]:
            if source.exists():
                IO.copy_file(source, destination)
                # also copy the plain-text .vocab file that sits next to each .model file
                src_txt_file = source.with_name(source.name.replace('.model', '.vocab'))
                if src_txt_file.exists():
                    dst_txt_file = destination.with_name(destination.name.replace('.model', '.vocab'))
                    IO.copy_file(src_txt_file, dst_txt_file)
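
A usage sketch mirroring the call in Example #2: vocab files flow from a source experiment into a destination experiment whose data dir is created on demand. The constructor arguments follow Example #2; the paths are placeholders:

from pathlib import Path

src_exp = TranslationExperiment(Path('runs/parent'), read_only=True)
dst_exp = TranslationExperiment(Path('runs/child'), config={'Not': 'Empty'}, read_only=True)
src_exp.copy_vocabs(dst_exp)
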
Example #4
    def inherit_parent(self):
        parent = self.config['parent']
        parent_exp = TranslationExperiment(parent['experiment'],
                                           read_only=True)
        log.info(f"Parent experiment: {parent_exp.work_dir}")
        parent_exp.has_prepared()
        vocab_spec = parent.get('vocab')
        if vocab_spec:
            log.info(f"Parent vocabs inheritance spec: {vocab_spec}")
            codec_lib = parent_exp.config['prep'].get('codec_lib')
            if codec_lib:
                self.config['prep']['codec_lib'] = codec_lib

            def _locate_field_file(exp: TranslationExperiment,
                                   name,
                                   check_exists=False) -> Path:
                switch = {
                    'src': exp._src_field_file,
                    'tgt': exp._tgt_field_file,
                    'shared': exp._shared_field_file
                }
                assert name in switch, f'{name} not allowed; valid options: {list(switch.keys())}'
                file = switch[name]
                if check_exists:
                    assert file.exists(), f'{file} does not exist for {name} of {exp.work_dir}'
                return file

            for to_field, from_field in vocab_spec.items():
                from_field_file = _locate_field_file(parent_exp,
                                                     from_field,
                                                     check_exists=True)
                to_field_file = _locate_field_file(self,
                                                   to_field,
                                                   check_exists=False)
                IO.copy_file(from_field_file, to_field_file)
            self.reload_vocabs()
        else:
            log.info("No vocabularies are inherited from parent")
        model_spec = parent.get('model')
        if model_spec:
            log.info(f"Parent model inheritance spec: {model_spec}")
            if model_spec.get('args'):
                self.model_args = parent_exp.model_args
            ensemble = model_spec.get('ensemble', 1)
            model_paths = parent_exp.list_models(sort='step', desc=True)[:ensemble]
            log.info(
                f"Averaging {len(model_paths)} checkpoints of parent model: \n{model_paths}"
            )
            from rtg.module.decoder import Decoder
            avg_state = Decoder.average_states(model_paths=model_paths)
            log.info(
                f"Saving parent model's state to {self.parent_model_state}")
            torch.save(avg_state, self.parent_model_state)

        shrink_spec = parent.get('shrink')
        if shrink_spec:
            # note: this reuses avg_state from the 'model' inheritance block above,
            # so a 'shrink' spec is only meaningful together with a 'model' spec
            remap_src, remap_tgt = self.shrink_vocabs()

            def map_rows(mapping: List[int], source: torch.Tensor, name=''):
                assert max(mapping) < len(source)
                target = torch.zeros((len(mapping), *source.shape[1:]),
                                     dtype=source.dtype,
                                     device=source.device)
                for new_idx, old_idx in enumerate(mapping):
                    target[new_idx] = source[old_idx]
                log.info(f"Mapped {name} {source.shape} → {target.shape}")
                return target

            """ src_embed.0.lut.weight [N x d]
                tgt_embed.0.lut.weight [N x d]
                generator.proj.weight [N x d]
                generator.proj.bias [N] """
            if remap_src:
                key = 'src_embed.0.lut.weight'
                avg_state[key] = map_rows(remap_src, avg_state[key], name=key)
            if remap_tgt:
                map_keys = [
                    'tgt_embed.0.lut.weight', 'generator.proj.weight',
                    'generator.proj.bias'
                ]
                for key in map_keys:
                    if key not in avg_state:
                        log.warning(
                            f'{key} not found in avg_state of parent model. Mapping skipped'
                        )
                        continue
                    avg_state[key] = map_rows(remap_tgt,
                                              avg_state[key],
                                              name=key)
            if self.parent_model_state.exists():
                self.parent_model_state.rename(
                    self.parent_model_state.with_suffix('.orig'))
            torch.save(avg_state, self.parent_model_state)
            self.persist_state()  # this fixes src_vocab and tgt_vocab of model_args conf
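
For reference, this is the shape of the 'parent' config block that inherit_parent() reads, reconstructed from the lookups in the code above; the keys come straight from the method, the values are illustrative only:

parent_spec = {
    'experiment': 'runs/parent-exp',           # parent['experiment']: path of the parent experiment
    'vocab': {'src': 'src', 'tgt': 'shared'},  # to_field: from_field pairs
    'model': {'args': True, 'ensemble': 5},    # inherit model_args; average 5 checkpoints
    'shrink': True,                            # shrink vocabs and remap embedding/generator rows
}
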
Example #5
    def export(self,
               target: Path,
               name: str = None,
               ensemble: int = 1,
               copy_config=True,
               copy_vocab=True):
        to_exp = Experiment(target.resolve(), config=self.exp.config)

        if copy_config:
            log.info("Copying config")
            to_exp.persist_state()

        if copy_vocab:
            log.info("Copying vocabulary")
            self.exp.copy_vocabs(to_exp)
        assert ensemble > 0
        assert name and len(name.split()) == 1, 'name must be a single token'
        log.info("Going to average models and then copy")
        model_paths = self.exp.list_models()[:ensemble]
        log.info(f'Model paths: {model_paths}')
        chkpt_state = torch.load(model_paths[0], map_location=device)
        if ensemble > 1:
            log.info("Averaging them ...")
            avg_state = Decoder.average_states(model_paths)
            chkpt_state = dict(model_state=avg_state,
                               model_type=chkpt_state['model_type'],
                               model_args=chkpt_state['model_args'])
        log.info("Instantiating it ...")
        model = instantiate_model(checkpt_state=chkpt_state, exp=self.exp)
        log.info(f"Exporting to {target}")
        to_exp.persist_state()  # to_exp was already created above

        IO.copy_file(self.exp.model_dir / 'scores.tsv',
                     to_exp.model_dir / 'scores.tsv')
        if (self.exp.work_dir / 'rtg.zip').exists():
            IO.copy_file(self.exp.work_dir / 'rtg.zip',
                         to_exp.work_dir / 'rtg.zip')

        src_chkpt = chkpt_state
        log.warning(
            "step number, training loss and validation loss are not recalculated."
        )
        step_num, train_loss, val_loss = [
            src_chkpt.get(n, -1) for n in ['step', 'train_loss', 'val_loss']
        ]
        copy_fields = [
            'optim_state', 'step', 'train_loss', 'valid_loss', 'time',
            'rtg_version', 'model_type', 'model_args'
        ]
        state = dict((c, src_chkpt[c]) for c in copy_fields if c in src_chkpt)
        state['model_state'] = model.state_dict()
        state['averaged_time'] = time.time()
        state['model_paths'] = model_paths
        state['num_checkpts'] = len(model_paths)
        prefix = f'model_{name}_avg{len(model_paths)}'
        to_exp.store_model(step_num,
                           state,
                           train_score=train_loss,
                           val_score=val_loss,
                           keep=10,
                           prefix=prefix)
        chkpts = [mp.name for mp in model_paths]
        status = {
            'parent': str(self.exp.work_dir),
            'ensemble': ensemble,
            'checkpts': chkpts,
            'when': datetime.datetime.now().isoformat(),
            'who': os.environ.get('USER', '<unknown>'),
        }
        # note: dumping to a Path works with ruamel.yaml; plain PyYAML expects an open file object
        yaml.dump(status, stream=to_exp.work_dir / '_EXPORTED')

        if self.exp._trained_flag.exists():
            IO.copy_file(self.exp._trained_flag, to_exp._trained_flag)
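
A hedged call sketch; `exporter` stands for whatever object hosts this method with a trained experiment bound to self.exp (that binding is all the snippet shows, the rest is assumption):

from pathlib import Path

# average the first 5 checkpoints returned by list_models() and
# store the result under the prefix model_base_avg5
exporter.export(Path('runs/exported'), name='base', ensemble=5,
                copy_config=True, copy_vocab=True)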