def _postprocess_batch(self, batch: OnePairBatch) -> OnePairBatch:
    """Attach the target-language id to the batch and move it to GPU when one is available."""
    lang_mapping = self.lang2id
    if lang_mapping is not None:
        # NOTE(j_luo) Source lang id not needed for now.
        batch.tgt_seqs.lang_id = lang_mapping[batch.tgt_seqs.lang]
    return batch.cuda() if has_gpus() else batch
def _get_model(self, dl=None):
    """Construct the model (RL agent or one-pair seq2seq), optionally restoring a checkpoint.

    The `dl` parameter is accepted for interface compatibility; it is not used here.
    """
    # Optional phonological-feature inputs, shared by both model families.
    pfm = None
    sp_ids = None
    if g.use_phono_features:
        pfm = get_tensor(self.src_abc.pfm)
        sp_ids = get_tensor(self.src_abc.special_ids)
    phono_kwargs = dict(phono_feat_mat=pfm, special_ids=sp_ids)

    if g.use_rl:
        # Policy-gradient agent operating on the environment's end state.
        agent_cls = VanillaPolicyGradient if g.agent == 'vpg' else A2C
        model = agent_cls(len(self.tgt_abc), self.env, self.env.end, **phono_kwargs)
    else:
        model = OnePairModel(len(self.src_abc), len(self.tgt_abc), **phono_kwargs)

    if g.saved_model_path is not None:
        # Load onto CPU first; moved to GPU below if available.
        state = torch.load(g.saved_model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state)
        logging.imp(f'Loaded from {g.saved_model_path}.')
    if has_gpus():
        model.cuda()
    logging.info(model)
    return model
def tgt_seqs(self) -> PaddedUnitSeqs:
    """Gather the entire target vocabulary into one padded sequence batch.

    Returns:
        A `PaddedUnitSeqs` holding the forms, unit sequences, id tensors and
        padding masks for every item in `self.tgt_vocabulary`, moved to GPU
        when one is available.
    """
    vocab = self.tgt_vocabulary
    # Comprehension instead of a range-len append loop (same item order).
    items = [vocab[i] for i in range(len(vocab))]
    ids, paddings = _gather_from_batches(items, 'id_seq')
    units = _gather_from_batches(items, 'unit_seq', is_tensor=False)
    forms = _gather_from_batches(items, 'form', is_tensor=False, is_seq=False)
    ret = PaddedUnitSeqs(self.tgt_lang, forms, units, ids, paddings)
    if has_gpus():
        ret.cuda()
    return ret
def entire_batch(self) -> OnePairBatch:
    """Return the loader's single batch, built and cached on first access.

    Raises:
        RuntimeError: if the underlying iterator yields more than one batch.
    """
    # Obtain the entire batch for the first time only.
    if self._entire_batch is None:
        lst = list(super().__iter__())
        if len(lst) != 1:
            raise RuntimeError(f"Expecting exactly one batch but got {len(lst)} instead.")
        batch = lst[0]
        if has_gpus():
            batch = batch.cuda()
        # Rename `batch` to `word`.
        batch.src_seqs.ids.rename_(batch='word')
        batch.tgt_seqs.ids.rename_(batch='word')
        self._entire_batch = batch
    return self._entire_batch
def __init__(self):
    """Set up data, model, evaluator and (unless evaluate-only) the trainer.

    Order matters here: raw data is prepared first, then data loaders are
    registered, the model is built (and optionally restored from a
    checkpoint), and finally the evaluator/trainer are wired together.
    """
    all_tgt, self.cog_reg, self.src_abc, self.tgt_abc = self.prepare_raw_data()
    # Get stats for unseen units.
    stats = self.tgt_abc.stats
    _, test_tgt_path = get_paths(g.data_path, g.src_lang, g.tgt_lang)
    # A unit is "unseen" if all of its occurrences come from the test split
    # (its total count equals its count in the test file).
    mask = (stats.sum() == stats.loc[test_tgt_path])
    unseen = mask[mask].index.tolist()
    total = len(stats.loc[test_tgt_path].dropna())
    logging.info(
        f'Unseen units ({len(unseen)}/{total}) for {g.tgt_lang} are: {unseen}.'
    )
    # Get language-to-id mappings. Used only for the targets (i.e., decoder side).
    self.lang2id = lang2id = {tgt: i for i, tgt in enumerate(all_tgt)}
    # Get all data loaders.
    self.dl_reg = DataLoaderRegistry()

    def create_setting(name: str, tgt_lang: str, split: Split, for_training: bool,
                       keep_ratio: Optional[float] = None,
                       tgt_sot: bool = False) -> Setting:
        # Small factory: every setting shares the source language and task name.
        return Setting(name, 'one_pair', split, g.src_lang, tgt_lang,
                       for_training, keep_ratio=keep_ratio, tgt_sot=tgt_sot)

    # Held-out evaluation setting for the main target language (uses all data).
    test_setting = create_setting(f'test@{g.tgt_lang}', g.tgt_lang,
                                  Split('all'), False,
                                  keep_ratio=g.test_keep_ratio)
    settings: List[Setting] = [test_setting]
    # Get the training languages.
    for train_tgt_lang in g.train_tgt_langs:
        if g.input_format == 'ielex':
            train_split = Split(
                'train', [1, 2, 3, 4])  # Use the first four folds for training.
            dev_split = Split('dev', [5])  # Use the last fold for dev.
        else:
            train_split = Split('train')
            dev_split = Split('dev')
        # Four loaders per training language: train, train-for-eval ("_e"),
        # dev and test.
        train_setting = create_setting(f'train@{train_tgt_lang}', train_tgt_lang,
                                       train_split, True, keep_ratio=g.keep_ratio)
        train_e_setting = create_setting(f'train@{train_tgt_lang}_e', train_tgt_lang,
                                         train_split, False, keep_ratio=g.keep_ratio)
        dev_setting = create_setting(f'dev@{train_tgt_lang}', train_tgt_lang,
                                     dev_split, False)
        test_setting = create_setting(f'test@{train_tgt_lang}', train_tgt_lang,
                                      Split('test'), False)
        settings.extend(
            [train_setting, train_e_setting, dev_setting, test_setting])
    for setting in settings:
        self.dl_reg.register_data_loader(setting, self.cog_reg, lang2id=lang2id)

    # Optional phonological-feature inputs for the model.
    phono_feat_mat = special_ids = None
    if g.use_phono_features:
        phono_feat_mat = get_tensor(self.src_abc.pfm)
        special_ids = get_tensor(self.src_abc.special_ids)
    # `+ 1` presumably accounts for the main target language on top of the
    # training languages — TODO confirm against OneToManyModel.
    self.model = OneToManyModel(len(self.src_abc), len(self.tgt_abc),
                                len(g.train_tgt_langs) + 1,
                                lang2id[g.tgt_lang],
                                lang2id=lang2id,
                                phono_feat_mat=phono_feat_mat,
                                special_ids=special_ids)
    if g.saved_model_path is not None:
        # Load onto CPU first; moved to GPU below if available.
        self.model.load_state_dict(
            torch.load(g.saved_model_path, map_location=torch.device('cpu')))
        logging.imp(f'Loaded from {g.saved_model_path}.')
    if has_gpus():
        self.model.cuda()
    logging.info(self.model)

    metric_writer = MetricWriter(g.log_dir, flush_secs=5)
    # NOTE(j_luo) Evaluate on every loader that is not for training.
    eval_dls = self.dl_reg.get_loaders_by_name(
        lambda name: 'train' not in name or '_e' in name)
    self.evaluator = Evaluator(self.model, eval_dls, self.tgt_abc,
                               metric_writer=metric_writer)

    if not g.evaluate_only:
        train_names = [
            f'train@{train_tgt_lang}' for train_tgt_lang in g.train_tgt_langs
        ]
        train_settings = [
            self.dl_reg.get_setting_by_name(name) for name in train_names
        ]
        # Equal weights across training settings; progress tracked by step.
        self.trainer = Trainer(self.model, train_settings,
                               [1.0] * len(train_settings), 'step',
                               stage_tnames=['step'],
                               evaluator=self.evaluator,
                               check_interval=g.check_interval,
                               eval_interval=g.eval_interval,
                               save_interval=g.save_interval,
                               metric_writer=metric_writer)
        if g.saved_model_path is None:
            # Only initialize parameters when not resuming from a checkpoint.
            # self.trainer.init_params('uniform', -0.1, 0.1)
            self.trainer.init_params('xavier_uniform')
        optim_cls = Adam if g.optim_cls == 'adam' else SGD
        self.trainer.set_optimizer(optim_cls, lr=g.learning_rate)