Example #1
    def _postprocess_batch(self, batch: OnePairBatch) -> OnePairBatch:
        if self.lang2id is not None:
            # NOTE(j_luo) Source lang id not needed for now.
            batch.tgt_seqs.lang_id = self.lang2id[batch.tgt_seqs.lang]
        if has_gpus():
            return batch.cuda()
        return batch
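
A minimal self-contained sketch of the same guard-then-move pattern, assuming stock PyTorch. `SimpleBatch` and this `has_gpus` are hypothetical stand-ins for the project's classes, not its actual API:

import torch

def has_gpus() -> bool:
    # Stand-in for the project's helper of the same name.
    return torch.cuda.is_available()

class SimpleBatch:
    # Hypothetical batch wrapping a single tensor.
    def __init__(self, ids: torch.Tensor):
        self.ids = ids

    def cuda(self) -> 'SimpleBatch':
        # Return a copy with every tensor moved to the GPU.
        return SimpleBatch(self.ids.cuda())

def postprocess(batch: SimpleBatch) -> SimpleBatch:
    # Move to the GPU only when one is available, as in _postprocess_batch above.
    return batch.cuda() if has_gpus() else batch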
Example #2
    def _get_model(self, dl=None):
        phono_feat_mat = special_ids = None
        if g.use_phono_features:
            phono_feat_mat = get_tensor(self.src_abc.pfm)
            special_ids = get_tensor(self.src_abc.special_ids)

        phono_kwargs = {
            'phono_feat_mat': phono_feat_mat,
            'special_ids': special_ids
        }
        if g.use_rl:
            end_state = self.env.end
            agent_cls = VanillaPolicyGradient if g.agent == 'vpg' else A2C
            model = agent_cls(len(self.tgt_abc), self.env, end_state,
                              **phono_kwargs)
        else:
            model = OnePairModel(len(self.src_abc), len(self.tgt_abc),
                                 **phono_kwargs)
        if g.saved_model_path is not None:
            model.load_state_dict(
                torch.load(g.saved_model_path,
                           map_location=torch.device('cpu')))
            logging.imp(f'Loaded from {g.saved_model_path}.')
        if has_gpus():
            model.cuda()
        logging.info(model)
        return model
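
The checkpoint-loading pattern above is worth isolating: restoring on CPU first keeps a GPU-saved checkpoint loadable anywhere. A sketch with stock PyTorch, using an nn.Linear as a hypothetical stand-in for OnePairModel:

import torch
import torch.nn as nn

model = nn.Linear(8, 4)
torch.save(model.state_dict(), 'saved.pth')

# map_location forces all saved tensors onto the CPU, so a checkpoint written
# on a GPU machine loads cleanly on a CPU-only one; .cuda() then moves the
# weights only if a device is actually present.
model.load_state_dict(torch.load('saved.pth', map_location=torch.device('cpu')))
if torch.cuda.is_available():
    model.cuda()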
Example #3
    def tgt_seqs(self) -> PaddedUnitSeqs:
        vocab = self.tgt_vocabulary
        items = [vocab[i] for i in range(len(vocab))]
        ids, paddings = _gather_from_batches(items, 'id_seq')
        units = _gather_from_batches(items, 'unit_seq', is_tensor=False)
        forms = _gather_from_batches(items, 'form', is_tensor=False, is_seq=False)
        ret = PaddedUnitSeqs(self.tgt_lang, forms, units, ids, paddings)
        if has_gpus():
            ret.cuda()
        return ret
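
`_gather_from_batches` is internal to the project; assuming it pads the uneven 'id_seq' tensors and returns a mask of real-versus-padded positions, an equivalent stand-alone sketch with stock PyTorch would look like:

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical uneven id sequences, analogous to the gathered 'id_seq' field.
seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

ids = pad_sequence(seqs)                        # shape: (3, 2), zero-padded
lengths = torch.tensor([len(s) for s in seqs])  # shape: (2,)
# True on real positions, False on padding, mirroring `paddings` above.
paddings = torch.arange(ids.size(0)).unsqueeze(1) < lengths  # shape: (3, 2)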
Example #4
    def entire_batch(self) -> OnePairBatch:
        # Lazily build and cache the entire batch on first access.
        if self._entire_batch is None:
            lst = list(super().__iter__())
            if len(lst) != 1:
                raise RuntimeError(f"Expecting exactly one batch but got {len(lst)} instead.")
            self._entire_batch = lst[0]
            if has_gpus():
                self._entire_batch = self._entire_batch.cuda()
            # Rename the `batch` dimension to `word`.
            self._entire_batch.src_seqs.ids.rename_(batch='word')
            self._entire_batch.tgt_seqs.ids.rename_(batch='word')

        return self._entire_batch
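
The rename_ call above uses PyTorch's (experimental) named-tensor API: keyword arguments map an existing dimension name to a new one, in place. A minimal demonstration:

import torch

ids = torch.zeros(5, 7, names=('batch', 'length'))
ids.rename_(batch='word')   # relabel the 'batch' dimension as 'word'
print(ids.names)            # ('word', 'length')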
Example #5
    def __init__(self):
        all_tgt, self.cog_reg, self.src_abc, self.tgt_abc = self.prepare_raw_data()

        # Find target units that occur only in the test data (i.e., unseen during training).
        stats = self.tgt_abc.stats
        _, test_tgt_path = get_paths(g.data_path, g.src_lang, g.tgt_lang)
        # A unit is unseen iff its total count across all paths equals its count
        # in the test path alone.
        mask = (stats.sum() == stats.loc[test_tgt_path])
        unseen = mask[mask].index.tolist()
        total = len(stats.loc[test_tgt_path].dropna())
        logging.info(
            f'Unseen units ({len(unseen)}/{total}) for {g.tgt_lang} are: {unseen}.')

        # Get language-to-id mappings. Used only for the targets (i.e., decoder side).
        self.lang2id = lang2id = {tgt: i for i, tgt in enumerate(all_tgt)}

        # Get all data loaders.
        self.dl_reg = DataLoaderRegistry()

        def create_setting(name: str,
                           tgt_lang: str,
                           split: Split,
                           for_training: bool,
                           keep_ratio: Optional[float] = None,
                           tgt_sot: bool = False) -> Setting:
            return Setting(name,
                           'one_pair',
                           split,
                           g.src_lang,
                           tgt_lang,
                           for_training,
                           keep_ratio=keep_ratio,
                           tgt_sot=tgt_sot)

        test_setting = create_setting(f'test@{g.tgt_lang}',
                                      g.tgt_lang,
                                      Split('all'),
                                      False,
                                      keep_ratio=g.test_keep_ratio)
        settings: List[Setting] = [test_setting]

        # Get the training languages.
        for train_tgt_lang in g.train_tgt_langs:
            if g.input_format == 'ielex':
                train_split = Split(
                    'train',
                    [1, 2, 3, 4])  # Use the first four folds for training.
                dev_split = Split('dev', [5])  # Use the last fold for dev.
            else:
                train_split = Split('train')
                dev_split = Split('dev')
            train_setting = create_setting(f'train@{train_tgt_lang}',
                                           train_tgt_lang,
                                           train_split,
                                           True,
                                           keep_ratio=g.keep_ratio)
            train_e_setting = create_setting(f'train@{train_tgt_lang}_e',
                                             train_tgt_lang,
                                             train_split,
                                             False,
                                             keep_ratio=g.keep_ratio)
            dev_setting = create_setting(f'dev@{train_tgt_lang}',
                                         train_tgt_lang, dev_split, False)
            test_setting = create_setting(f'test@{train_tgt_lang}',
                                          train_tgt_lang, Split('test'), False)

            settings.extend(
                [train_setting, train_e_setting, dev_setting, test_setting])
        for setting in settings:
            self.dl_reg.register_data_loader(setting,
                                             self.cog_reg,
                                             lang2id=lang2id)

        phono_feat_mat = special_ids = None
        if g.use_phono_features:
            phono_feat_mat = get_tensor(self.src_abc.pfm)
            special_ids = get_tensor(self.src_abc.special_ids)

        self.model = OneToManyModel(len(self.src_abc),
                                    len(self.tgt_abc),
                                    len(g.train_tgt_langs) + 1,
                                    lang2id[g.tgt_lang],
                                    lang2id=lang2id,
                                    phono_feat_mat=phono_feat_mat,
                                    special_ids=special_ids)

        if g.saved_model_path is not None:
            self.model.load_state_dict(
                torch.load(g.saved_model_path,
                           map_location=torch.device('cpu')))
            logging.imp(f'Loaded from {g.saved_model_path}.')
        if has_gpus():
            self.model.cuda()
        logging.info(self.model)

        metric_writer = MetricWriter(g.log_dir, flush_secs=5)

        # NOTE(j_luo) Evaluate on every loader that is not for training.
        eval_dls = self.dl_reg.get_loaders_by_name(
            lambda name: 'train' not in name or '_e' in name)
        self.evaluator = Evaluator(self.model,
                                   eval_dls,
                                   self.tgt_abc,
                                   metric_writer=metric_writer)

        if not g.evaluate_only:
            train_names = [
                f'train@{train_tgt_lang}'
                for train_tgt_lang in g.train_tgt_langs
            ]
            train_settings = [
                self.dl_reg.get_setting_by_name(name) for name in train_names
            ]
            self.trainer = Trainer(self.model,
                                   train_settings, [1.0] * len(train_settings),
                                   'step',
                                   stage_tnames=['step'],
                                   evaluator=self.evaluator,
                                   check_interval=g.check_interval,
                                   eval_interval=g.eval_interval,
                                   save_interval=g.save_interval,
                                   metric_writer=metric_writer)
            if g.saved_model_path is None:
                # self.trainer.init_params('uniform', -0.1, 0.1)
                self.trainer.init_params('xavier_uniform')
            optim_cls = Adam if g.optim_cls == 'adam' else SGD
            self.trainer.set_optimizer(optim_cls, lr=g.learning_rate)
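
The unseen-unit computation near the top of __init__ is compact; assuming `tgt_abc.stats` is a path-by-unit count table (rows indexed by data path, columns by unit), a hypothetical pandas sketch of the same logic:

import pandas as pd

stats = pd.DataFrame({'a': [3, 1], 'b': [0, 2], 'c': [1, 0]},
                     index=['train.tsv', 'test.tsv'])

# A unit is unseen in training iff its column total equals its count in the
# test row alone, i.e. every occurrence comes from the test data.
mask = stats.sum() == stats.loc['test.tsv']
unseen = mask[mask].index.tolist()  # ['b']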