Esempio n. 1
0
    def export(self,
               target: Path,
               name: str = None,
               ensemble: int = 1,
               copy_config=True,
               copy_vocab=True):
        """Export this experiment (optionally with an averaged model) to another dir.

        :param target: work directory of the experiment to export into.
        :param name: single-token name used in the exported checkpoint prefix
            ``model_{name}_avg{k}``; required when ``ensemble > 0``.
        :param ensemble: number of checkpoints (the first ``ensemble`` entries of
            ``self.exp.list_models()``) to average and export. When ``<= 0`` no
            model is exported; only the config/vocab copies below happen.
        :param copy_config: persist the source config into ``target`` up front.
        :param copy_vocab: copy the source vocabularies into ``target``.
        """
        to_exp = Experiment(target, config=self.exp.config)

        if copy_config:
            log.info("Copying config")
            to_exp.persist_state()

        if copy_vocab:
            log.info("Copying vocabulary")
            self.exp.copy_vocabs(to_exp)

        if ensemble > 0:
            assert name
            assert len(name.split()) == 1
            log.info("Going to average models and then copy")
            model_paths = self.exp.list_models()[:ensemble]
            # fail with a clear message instead of an IndexError on checkpts[0] below
            assert model_paths, 'No model checkpoints found to export'
            log.info(f'Model paths: {model_paths}')
            checkpts = [torch.load(mp) for mp in model_paths]
            states = [chkpt['model_state'] for chkpt in checkpts]

            log.info("Averaging them ...")
            avg_state = Decoder.average_states(*states)
            chkpt_state = dict(model_state=avg_state,
                               model_type=checkpts[0]['model_type'],
                               model_args=checkpts[0]['model_args'])
            log.info("Instantiating it ...")
            model = instantiate_model(checkpt_state=chkpt_state, exp=self.exp)
            log.info(f"Exporting to {target}")
            # NOTE(review): this re-creates the target experiment and persists its
            # config even when copy_config=False; kept as-is — confirm intent.
            to_exp = Experiment(target, config=self.exp.config)
            to_exp.persist_state()

            # Metadata (step, losses, optimizer state, ...) comes from the first
            # checkpoint; only the model weights are averaged.
            src_chkpt = checkpts[0]
            log.warning(
                "step number, training loss and validation loss are not recalculated."
            )
            step_num, train_loss, val_loss = [
                src_chkpt.get(n, -1)
                for n in ['step', 'train_loss', 'val_loss']
            ]
            # 'val_loss' is the key read just above, so carry it over too;
            # 'valid_loss' is retained from the original field list.
            copy_fields = [
                'optim_state', 'step', 'train_loss', 'val_loss', 'valid_loss',
                'time', 'rtg_version', 'model_type', 'model_args'
            ]
            state = {c: src_chkpt[c] for c in copy_fields if c in src_chkpt}
            state['model_state'] = model.state_dict()
            state['averaged_time'] = time.time()
            state['model_paths'] = model_paths
            state['num_checkpts'] = len(model_paths)
            prefix = f'model_{name}_avg{len(model_paths)}'
            to_exp.store_model(step_num,
                               state,
                               train_score=train_loss,
                               val_score=val_loss,
                               keep=10,
                               prefix=prefix)
Esempio n. 2
0
File: exp.py Progetto: isi-nlp/rtg
    def inherit_parent(self):
        """Inherit vocabularies and/or model weights from a parent experiment.

        Driven by ``self.config['parent']``, whose recognized keys are:
          * ``experiment``: work dir of the parent experiment (opened read-only).
          * ``vocab``: mapping ``{to_field: from_field}`` over ``src``/``tgt``/``shared``;
            each parent field file is copied over this experiment's field file,
            then vocabs are reloaded. The parent's ``prep.codec_lib`` is adopted too.
          * ``model``: ``args`` (truthy => copy the parent's ``model_args``) and
            ``ensemble`` (how many of the parent's latest-by-step checkpoints to
            average; default 1). The averaged state is saved to
            ``self.parent_model_state``.
          * ``shrink``: shrink this experiment's vocabularies and remap the
            embedding/generator rows of the inherited model state to match.
            Requires ``model`` to be specified as well.

        :raises ValueError: if ``shrink`` is requested without ``model``.
        """
        parent = self.config['parent']
        parent_exp = TranslationExperiment(parent['experiment'],
                                           read_only=True)
        log.info(f"Parent experiment: {parent_exp.work_dir}")
        # NOTE(review): the return value is ignored, so this does not enforce that
        # the parent was prepared — confirm whether an assert was intended.
        parent_exp.has_prepared()
        vocab_spec = parent.get('vocab')
        if vocab_spec:
            log.info(f"Parent vocabs inheritance spec: {vocab_spec}")
            # adopt the parent's codec library along with its vocab files
            codec_lib = parent_exp.config['prep'].get('codec_lib')
            if codec_lib:
                self.config['prep']['codec_lib'] = codec_lib

            def _locate_field_file(exp: TranslationExperiment,
                                   name,
                                   check_exists=False) -> Path:
                """Return `exp`'s field file for `name` (one of src, tgt, shared)."""
                switch = {
                    'src': exp._src_field_file,
                    'tgt': exp._tgt_field_file,
                    'shared': exp._shared_field_file
                }
                assert name in switch, f'{name} not allowed; valid options= {switch.keys()}'
                file = switch[name]
                if check_exists:
                    assert file.exists(
                    ), f'{file} doesnot exist; for {name} of {exp.work_dir}'
                return file

            for to_field, from_field in vocab_spec.items():
                from_field_file = _locate_field_file(parent_exp,
                                                     from_field,
                                                     check_exists=True)
                to_field_file = _locate_field_file(self,
                                                   to_field,
                                                   check_exists=False)
                IO.copy_file(from_field_file, to_field_file)
            self.reload_vocabs()
        else:
            log.info("No vocabularies are inherited from parent")

        # Averaged parent model state; remains None unless parent['model'] is given.
        avg_state = None
        model_spec = parent.get('model')
        if model_spec:
            log.info("Parent model inheritance spec")
            if model_spec.get('args'):
                self.model_args = parent_exp.model_args
            ensemble = model_spec.get('ensemble', 1)
            model_paths = parent_exp.list_models(sort='step',
                                                 desc=True)[:ensemble]
            log.info(
                f"Averaging {len(model_paths)} checkpoints of parent model: \n{model_paths}"
            )
            from rtg.module.decoder import Decoder
            avg_state = Decoder.average_states(model_paths=model_paths)
            log.info(
                f"Saving parent model's state to {self.parent_model_state}")
            torch.save(avg_state, self.parent_model_state)

        shrink_spec = parent.get('shrink')
        if shrink_spec:
            if avg_state is None:
                # Fail before shrinking vocabs or renaming files: without an
                # inherited model state there is nothing to remap.
                raise ValueError(
                    "parent.shrink requires parent.model to be specified;"
                    " there is no parent model state to remap")
            remap_src, remap_tgt = self.shrink_vocabs()

            def map_rows(mapping: List[int], source: torch.Tensor, name=''):
                """Return a new tensor whose i-th row is ``source[mapping[i]]``."""
                assert max(mapping) < len(source)
                # integer-array indexing gathers all rows in one op and returns a copy
                # with source's dtype and device
                index = torch.tensor(mapping,
                                     dtype=torch.long,
                                     device=source.device)
                target = source[index]
                log.info(f"Mapped {name} {source.shape} --> {target.shape} ")
                return target

            # Tensors whose first dimension is indexed by vocabulary id:
            #   src_embed.0.lut.weight [N x d]
            #   tgt_embed.0.lut.weight [N x d]
            #   generator.proj.weight  [N x d]
            #   generator.proj.bias    [N]
            if remap_src:
                key = 'src_embed.0.lut.weight'
                avg_state[key] = map_rows(remap_src, avg_state[key], name=key)
            if remap_tgt:
                map_keys = [
                    'tgt_embed.0.lut.weight', 'generator.proj.weight',
                    'generator.proj.bias'
                ]
                for key in map_keys:
                    if key not in avg_state:
                        log.warning(
                            f'{key} not found in avg_state of parent model. Mapping skipped'
                        )
                        continue
                    avg_state[key] = map_rows(remap_tgt,
                                              avg_state[key],
                                              name=key)
            # keep the un-shrunk state around as '.orig' before overwriting
            if self.parent_model_state.exists():
                self.parent_model_state.rename(
                    self.parent_model_state.with_suffix('.orig'))
            torch.save(avg_state, self.parent_model_state)
            self.persist_state(
            )  # this will fix src_vocab and tgt_vocab of model_args conf
Esempio n. 3
0
    def export(self,
               target: Path,
               name: str = None,
               ensemble: int = 1,
               copy_config=True,
               copy_vocab=True):
        """Export the latest (optionally averaged) checkpoint(s) to another experiment dir.

        :param target: work directory of the experiment to export into.
        :param name: single-token name used in the exported checkpoint prefix
            ``model_{name}_avg{k}``; required.
        :param ensemble: number of checkpoints (the first ``ensemble`` entries of
            ``self.exp.list_models()``) to use; must be > 0. With 1 the first
            checkpoint's model state is exported without averaging.
        :param copy_config: persist the source config into ``target`` up front.
        :param copy_vocab: copy the source vocabularies into ``target``.

        Besides the checkpoint, this copies ``scores.tsv``, ``rtg.zip`` (if present)
        and the trained flag (if present), and writes an ``_EXPORTED`` YAML record.
        """
        to_exp = Experiment(target.resolve(), config=self.exp.config)

        if copy_config:
            log.info("Copying config")
            to_exp.persist_state()

        if copy_vocab:
            log.info("Copying vocabulary")
            self.exp.copy_vocabs(to_exp)
        assert ensemble > 0
        assert name
        assert len(name.split()) == 1
        log.info("Going to average models and then copy")
        model_paths = self.exp.list_models()[:ensemble]
        # fail with a clear message instead of an IndexError on model_paths[0]
        assert model_paths, 'No model checkpoints found to export'
        log.info(f'Model paths: {model_paths}')
        # The first checkpoint supplies the metadata (step, losses, optimizer
        # state, ...) for the export. Keep it separate from chkpt_state so the
        # metadata is not lost when chkpt_state is replaced by the averaged state.
        src_chkpt = torch.load(model_paths[0], map_location=device)
        chkpt_state = src_chkpt
        if ensemble > 1:
            log.info("Averaging them ...")
            avg_state = Decoder.average_states(model_paths)
            chkpt_state = dict(model_state=avg_state,
                               model_type=src_chkpt['model_type'],
                               model_args=src_chkpt['model_args'])
        log.info("Instantiating it ...")
        model = instantiate_model(checkpt_state=chkpt_state, exp=self.exp)
        log.info(f"Exporting to {target}")
        # NOTE(review): this re-creates the target experiment and persists its
        # config even when copy_config=False; kept as-is — confirm intent.
        to_exp = Experiment(target, config=self.exp.config)
        to_exp.persist_state()

        IO.copy_file(self.exp.model_dir / 'scores.tsv',
                     to_exp.model_dir / 'scores.tsv')
        if (self.exp.work_dir / 'rtg.zip').exists():
            IO.copy_file(self.exp.work_dir / 'rtg.zip',
                         to_exp.work_dir / 'rtg.zip')

        log.warning(
            "step number, training loss and validation loss are not recalculated."
        )
        step_num, train_loss, val_loss = [
            src_chkpt.get(n, -1) for n in ['step', 'train_loss', 'val_loss']
        ]
        # 'val_loss' is the key read just above, so carry it over too;
        # 'valid_loss' is retained from the original field list.
        copy_fields = [
            'optim_state', 'step', 'train_loss', 'val_loss', 'valid_loss',
            'time', 'rtg_version', 'model_type', 'model_args'
        ]
        state = {c: src_chkpt[c] for c in copy_fields if c in src_chkpt}
        state['model_state'] = model.state_dict()
        state['averaged_time'] = time.time()
        state['model_paths'] = model_paths
        state['num_checkpts'] = len(model_paths)
        prefix = f'model_{name}_avg{len(model_paths)}'
        to_exp.store_model(step_num,
                           state,
                           train_score=train_loss,
                           val_score=val_loss,
                           keep=10,
                           prefix=prefix)
        chkpts = [mp.name for mp in model_paths]
        status = {
            'parent': str(self.exp.work_dir),
            'ensemble': ensemble,
            'checkpts': chkpts,
            'when': datetime.datetime.now().isoformat(),
            'who': os.environ.get('USER', '<unknown>'),
        }
        yaml.dump(status, stream=to_exp.work_dir / '_EXPORTED')

        if self.exp._trained_flag.exists():
            IO.copy_file(self.exp._trained_flag, to_exp._trained_flag)